Skip to main content

ember_persistence/
snapshot.rs

1//! Point-in-time snapshot files.
2//!
//! Each shard writes its own snapshot (`shard-{id}.snap`). The format
//! stores all live entries in a single pass. Writes go to a `.snap.tmp`
//! file first and are atomically renamed on completion — this ensures
//! a partial or crashed snapshot never corrupts the existing `.snap` file.
7//!
8//! File layout:
9//! ```text
10//! [ESNP magic: 4B][version: 1B][shard_id: 2B][entry_count: 4B]
11//! [entries...]
//! [footer_crc32: 4B]
//! ```
//!
//! The footer CRC32 covers the entry bytes only; the header is not checksummed.
14//!
15//! Each entry (v2, type-tagged):
16//! ```text
17//! [key_len: 4B][key][type_tag: 1B][type-specific payload][expire_ms: 8B]
18//! ```
19//!
20//! Type tags: 0=string, 1=list, 2=sorted set.
21//! `expire_ms` is the remaining TTL in milliseconds, or -1 for no expiry.
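//! v1 entries have no type tag and always hold a string value
//! (`[key_len: 4B][key][value_len: 4B][value][expire_ms: 8B]`); they are
//! still readable for backward compatibility.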
23
24use std::collections::VecDeque;
25use std::fs::{self, File};
26use std::io::{self, BufReader, BufWriter, Write};
27use std::path::{Path, PathBuf};
28
29use bytes::Bytes;
30
31use crate::format::{self, FormatError};
32
// On-disk type tags: the byte written right after the key of every v2 entry.

/// Tag for a [`SnapValue::String`] payload.
const TYPE_STRING: u8 = 0x00;
/// Tag for a [`SnapValue::List`] payload.
const TYPE_LIST: u8 = 0x01;
/// Tag for a [`SnapValue::SortedSet`] payload.
const TYPE_SORTED_SET: u8 = 0x02;
37
38/// The value stored in a snapshot entry.
39#[derive(Debug, Clone, PartialEq)]
40pub enum SnapValue {
41    /// A string value.
42    String(Bytes),
43    /// A list of values.
44    List(VecDeque<Bytes>),
45    /// A sorted set: vec of (score, member) pairs.
46    SortedSet(Vec<(f64, String)>),
47}
48
49/// A single entry in a snapshot file.
50#[derive(Debug, Clone, PartialEq)]
51pub struct SnapEntry {
52    pub key: String,
53    pub value: SnapValue,
54    /// Remaining TTL in milliseconds, or -1 for no expiration.
55    pub expire_ms: i64,
56}
57
58/// Writes a complete snapshot to disk.
59///
60/// Entries are written to a temporary file first, then atomically
61/// renamed to the final path. The caller provides an iterator over
62/// the entries to write.
63pub struct SnapshotWriter {
64    final_path: PathBuf,
65    tmp_path: PathBuf,
66    writer: BufWriter<File>,
67    /// Running CRC over all entry bytes for the footer checksum.
68    hasher: crc32fast::Hasher,
69    count: u32,
70    /// Set to `true` after a successful `finish()`. Used by the `Drop`
71    /// impl to clean up incomplete temp files.
72    finished: bool,
73}
74
75impl SnapshotWriter {
76    /// Creates a new snapshot writer. The file won't appear at `path`
77    /// until [`finish`] is called successfully.
78    pub fn create(path: impl Into<PathBuf>, shard_id: u16) -> Result<Self, FormatError> {
79        let final_path = path.into();
80        let tmp_path = final_path.with_extension("snap.tmp");
81
82        let file = File::create(&tmp_path)?;
83        let mut writer = BufWriter::new(file);
84
85        // write header: magic + version + shard_id + placeholder entry count
86        format::write_header(&mut writer, format::SNAP_MAGIC)?;
87        format::write_u16(&mut writer, shard_id)?;
88        // entry count — we'll seek back and update, or just write it now
89        // and track. since we're streaming, write 0 and update after.
90        format::write_u32(&mut writer, 0)?;
91
92        Ok(Self {
93            final_path,
94            tmp_path,
95            writer,
96            hasher: crc32fast::Hasher::new(),
97            count: 0,
98            finished: false,
99        })
100    }
101
102    /// Writes a single entry to the snapshot.
103    pub fn write_entry(&mut self, entry: &SnapEntry) -> Result<(), FormatError> {
104        let mut buf = Vec::new();
105        format::write_bytes(&mut buf, entry.key.as_bytes())?;
106        match &entry.value {
107            SnapValue::String(data) => {
108                format::write_u8(&mut buf, TYPE_STRING)?;
109                format::write_bytes(&mut buf, data)?;
110            }
111            SnapValue::List(deque) => {
112                format::write_u8(&mut buf, TYPE_LIST)?;
113                format::write_u32(&mut buf, deque.len() as u32)?;
114                for item in deque {
115                    format::write_bytes(&mut buf, item)?;
116                }
117            }
118            SnapValue::SortedSet(members) => {
119                format::write_u8(&mut buf, TYPE_SORTED_SET)?;
120                format::write_u32(&mut buf, members.len() as u32)?;
121                for (score, member) in members {
122                    format::write_f64(&mut buf, *score)?;
123                    format::write_bytes(&mut buf, member.as_bytes())?;
124                }
125            }
126        }
127        format::write_i64(&mut buf, entry.expire_ms)?;
128
129        self.hasher.update(&buf);
130        self.writer.write_all(&buf)?;
131        self.count += 1;
132        Ok(())
133    }
134
135    /// Finalizes the snapshot: writes the footer CRC, flushes, and
136    /// atomically renames the temp file to the final path.
137    pub fn finish(mut self) -> Result<(), FormatError> {
138        // write footer CRC — clone the hasher so we don't move out of self
139        let checksum = self.hasher.clone().finalize();
140        format::write_u32(&mut self.writer, checksum)?;
141        self.writer.flush()?;
142        self.writer.get_ref().sync_all()?;
143
144        // rewrite the header with the correct entry count.
145        // open a second handle for the seek — the BufWriter is already
146        // flushed and synced above.
147        {
148            use std::io::{Seek, SeekFrom};
149            let mut file = fs::OpenOptions::new().write(true).open(&self.tmp_path)?;
150            // header: 4 (magic) + 1 (version) + 2 (shard_id) = 7 bytes
151            file.seek(SeekFrom::Start(7))?;
152            format::write_u32(&mut file, self.count)?;
153            file.sync_all()?;
154        }
155
156        // atomic rename
157        fs::rename(&self.tmp_path, &self.final_path)?;
158        self.finished = true;
159        Ok(())
160    }
161}
162
163impl Drop for SnapshotWriter {
164    fn drop(&mut self) {
165        if !self.finished {
166            // best-effort cleanup of incomplete temp file
167            let _ = fs::remove_file(&self.tmp_path);
168        }
169    }
170}
171
172/// Reads entries from a snapshot file.
173pub struct SnapshotReader {
174    reader: BufReader<File>,
175    pub shard_id: u16,
176    pub entry_count: u32,
177    read_so_far: u32,
178    hasher: crc32fast::Hasher,
179    /// Format version — v1 has no type tags, v2 has type-tagged entries.
180    version: u8,
181}
182
183impl SnapshotReader {
184    /// Opens a snapshot file and reads the header.
185    pub fn open(path: impl AsRef<Path>) -> Result<Self, FormatError> {
186        let file = File::open(path.as_ref())?;
187        let mut reader = BufReader::new(file);
188
189        let version = format::read_header(&mut reader, format::SNAP_MAGIC)?;
190        let shard_id = format::read_u16(&mut reader)?;
191        let entry_count = format::read_u32(&mut reader)?;
192
193        Ok(Self {
194            reader,
195            shard_id,
196            entry_count,
197            read_so_far: 0,
198            hasher: crc32fast::Hasher::new(),
199            version,
200        })
201    }
202
203    /// Reads the next entry. Returns `None` when all entries have been read.
204    pub fn read_entry(&mut self) -> Result<Option<SnapEntry>, FormatError> {
205        if self.read_so_far >= self.entry_count {
206            return Ok(None);
207        }
208
209        let mut buf = Vec::new();
210
211        let key_bytes = format::read_bytes(&mut self.reader)?;
212        format::write_bytes(&mut buf, &key_bytes).expect("vec write");
213
214        let value = if self.version == 1 {
215            // v1: no type tag, value is always a string
216            let value_bytes = format::read_bytes(&mut self.reader)?;
217            format::write_bytes(&mut buf, &value_bytes).expect("vec write");
218            SnapValue::String(Bytes::from(value_bytes))
219        } else {
220            // v2+: type-tagged values
221            let type_tag = format::read_u8(&mut self.reader)?;
222            format::write_u8(&mut buf, type_tag).expect("vec write");
223            match type_tag {
224                TYPE_STRING => {
225                    let value_bytes = format::read_bytes(&mut self.reader)?;
226                    format::write_bytes(&mut buf, &value_bytes).expect("vec write");
227                    SnapValue::String(Bytes::from(value_bytes))
228                }
229                TYPE_LIST => {
230                    let count = format::read_u32(&mut self.reader)?;
231                    format::write_u32(&mut buf, count).expect("vec write");
232                    let mut deque = VecDeque::with_capacity(count as usize);
233                    for _ in 0..count {
234                        let item = format::read_bytes(&mut self.reader)?;
235                        format::write_bytes(&mut buf, &item).expect("vec write");
236                        deque.push_back(Bytes::from(item));
237                    }
238                    SnapValue::List(deque)
239                }
240                TYPE_SORTED_SET => {
241                    let count = format::read_u32(&mut self.reader)?;
242                    format::write_u32(&mut buf, count).expect("vec write");
243                    let mut members = Vec::with_capacity(count as usize);
244                    for _ in 0..count {
245                        let score = format::read_f64(&mut self.reader)?;
246                        format::write_f64(&mut buf, score).expect("vec write");
247                        let member_bytes = format::read_bytes(&mut self.reader)?;
248                        format::write_bytes(&mut buf, &member_bytes).expect("vec write");
249                        let member = String::from_utf8(member_bytes).map_err(|_| {
250                            FormatError::Io(io::Error::new(
251                                io::ErrorKind::InvalidData,
252                                "member is not valid utf-8",
253                            ))
254                        })?;
255                        members.push((score, member));
256                    }
257                    SnapValue::SortedSet(members)
258                }
259                _ => {
260                    return Err(FormatError::UnknownTag(type_tag));
261                }
262            }
263        };
264
265        let expire_ms = format::read_i64(&mut self.reader)?;
266        format::write_i64(&mut buf, expire_ms).expect("vec write");
267        self.hasher.update(&buf);
268
269        let key = String::from_utf8(key_bytes).map_err(|_| {
270            FormatError::Io(io::Error::new(
271                io::ErrorKind::InvalidData,
272                "key is not valid utf-8",
273            ))
274        })?;
275
276        self.read_so_far += 1;
277        Ok(Some(SnapEntry {
278            key,
279            value,
280            expire_ms,
281        }))
282    }
283
284    /// Verifies the footer CRC32 after all entries have been read.
285    /// Must be called after reading all entries.
286    pub fn verify_footer(self) -> Result<(), FormatError> {
287        let expected = self.hasher.finalize();
288        let mut reader = self.reader;
289        let stored = format::read_u32(&mut reader)?;
290        format::verify_crc32_values(expected, stored)
291    }
292}
293
/// Builds the location of shard `shard_id`'s snapshot inside `data_dir`,
/// e.g. `<data_dir>/shard-3.snap`.
pub fn snapshot_path(data_dir: &Path, shard_id: u16) -> PathBuf {
    let file_name = format!("shard-{shard_id}.snap");
    let mut full = data_dir.to_path_buf();
    full.push(file_name);
    full
}
298
#[cfg(test)]
mod tests {
    use super::*;

    fn temp_dir() -> tempfile::TempDir {
        tempfile::tempdir().expect("create temp dir")
    }

    #[test]
    fn empty_snapshot_round_trip() {
        let dir = temp_dir();
        let path = dir.path().join("empty.snap");

        {
            let writer = SnapshotWriter::create(&path, 0).unwrap();
            writer.finish().unwrap();
        }

        // the temp file must be gone once the snapshot is renamed into place
        assert!(!path.with_extension("snap.tmp").exists());

        let reader = SnapshotReader::open(&path).unwrap();
        assert_eq!(reader.shard_id, 0);
        assert_eq!(reader.entry_count, 0);
        reader.verify_footer().unwrap();
    }

    #[test]
    fn entries_round_trip() {
        let dir = temp_dir();
        let path = dir.path().join("data.snap");

        let entries = vec![
            SnapEntry {
                key: "hello".into(),
                value: SnapValue::String(Bytes::from("world")),
                expire_ms: -1,
            },
            SnapEntry {
                key: "ttl".into(),
                value: SnapValue::String(Bytes::from("expiring")),
                expire_ms: 5000,
            },
            SnapEntry {
                key: "empty".into(),
                value: SnapValue::String(Bytes::new()),
                expire_ms: -1,
            },
        ];

        {
            let mut writer = SnapshotWriter::create(&path, 7).unwrap();
            for entry in &entries {
                writer.write_entry(entry).unwrap();
            }
            writer.finish().unwrap();
        }

        let mut reader = SnapshotReader::open(&path).unwrap();
        assert_eq!(reader.shard_id, 7);
        assert_eq!(reader.entry_count, 3);

        let mut got = Vec::new();
        while let Some(entry) = reader.read_entry().unwrap() {
            got.push(entry);
        }
        assert_eq!(entries, got);
        reader.verify_footer().unwrap();
    }

    #[test]
    fn corrupt_footer_detected() {
        let dir = temp_dir();
        let path = dir.path().join("corrupt.snap");

        {
            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
            writer
                .write_entry(&SnapEntry {
                    key: "k".into(),
                    value: SnapValue::String(Bytes::from("v")),
                    expire_ms: -1,
                })
                .unwrap();
            writer.finish().unwrap();
        }

        // corrupt the last byte (footer CRC)
        let mut data = fs::read(&path).unwrap();
        let last = data.len() - 1;
        data[last] ^= 0xFF;
        fs::write(&path, &data).unwrap();

        let mut reader = SnapshotReader::open(&path).unwrap();
        // reading entries should still work
        reader.read_entry().unwrap();
        // but footer verification should fail
        let err = reader.verify_footer().unwrap_err();
        assert!(matches!(err, FormatError::ChecksumMismatch { .. }));
    }

    #[test]
    fn corrupt_entry_detected() {
        let dir = temp_dir();
        let path = dir.path().join("corrupt_entry.snap");

        {
            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
            writer
                .write_entry(&SnapEntry {
                    key: "k".into(),
                    value: SnapValue::String(Bytes::from("v")),
                    expire_ms: -1,
                })
                .unwrap();
            writer.finish().unwrap();
        }

        // the value byte sits right before expire_ms (8B) and the footer (4B)
        let mut data = fs::read(&path).unwrap();
        let value_pos = data.len() - 4 - 8 - 1;
        assert_eq!(data[value_pos], b'v');
        data[value_pos] = b'x';
        fs::write(&path, &data).unwrap();

        let mut reader = SnapshotReader::open(&path).unwrap();
        // the entry still decodes, just with the tampered value...
        let entry = reader.read_entry().unwrap().unwrap();
        assert_eq!(entry.value, SnapValue::String(Bytes::from("x")));
        // ...but the footer CRC no longer matches
        let err = reader.verify_footer().unwrap_err();
        assert!(matches!(err, FormatError::ChecksumMismatch { .. }));
    }

    #[test]
    fn unknown_type_tag_rejected() {
        let dir = temp_dir();
        let path = dir.path().join("bad_tag.snap");

        {
            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
            writer
                .write_entry(&SnapEntry {
                    key: "k".into(),
                    value: SnapValue::String(Bytes::from("v")),
                    expire_ms: -1,
                })
                .unwrap();
            writer.finish().unwrap();
        }

        // tail layout: [type_tag][value_len: 4B]["v"][expire_ms: 8B][crc: 4B]
        let mut data = fs::read(&path).unwrap();
        let tag_pos = data.len() - 4 - 8 - 1 - 4 - 1;
        assert_eq!(data[tag_pos], TYPE_STRING);
        data[tag_pos] = 0xEE;
        fs::write(&path, &data).unwrap();

        let mut reader = SnapshotReader::open(&path).unwrap();
        let err = reader.read_entry().unwrap_err();
        assert!(matches!(err, FormatError::UnknownTag(0xEE)));
    }

    #[test]
    fn atomic_rename_prevents_partial_snapshots() {
        let dir = temp_dir();
        let path = dir.path().join("atomic.snap");

        // write an initial snapshot
        {
            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
            writer
                .write_entry(&SnapEntry {
                    key: "original".into(),
                    value: SnapValue::String(Bytes::from("data")),
                    expire_ms: -1,
                })
                .unwrap();
            writer.finish().unwrap();
        }

        // start a second snapshot but don't finish it
        {
            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
            writer
                .write_entry(&SnapEntry {
                    key: "new".into(),
                    value: SnapValue::String(Bytes::from("partial")),
                    expire_ms: -1,
                })
                .unwrap();
            // drop without finish — simulates a crash
            drop(writer);
        }

        // the original snapshot should still be intact and complete
        let mut reader = SnapshotReader::open(&path).unwrap();
        assert_eq!(reader.entry_count, 1);
        let entry = reader.read_entry().unwrap().unwrap();
        assert_eq!(entry.key, "original");
        assert!(reader.read_entry().unwrap().is_none());
        reader.verify_footer().unwrap();

        // the tmp file should have been cleaned up by Drop
        let tmp = path.with_extension("snap.tmp");
        assert!(!tmp.exists(), "drop should clean up incomplete tmp file");
    }

    #[test]
    fn ttl_entries_preserved() {
        let dir = temp_dir();
        let path = dir.path().join("ttl.snap");

        let entry = SnapEntry {
            key: "expires".into(),
            value: SnapValue::String(Bytes::from("soon")),
            expire_ms: 42_000,
        };

        {
            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
            writer.write_entry(&entry).unwrap();
            writer.finish().unwrap();
        }

        let mut reader = SnapshotReader::open(&path).unwrap();
        let got = reader.read_entry().unwrap().unwrap();
        assert_eq!(got.expire_ms, 42_000);
        reader.verify_footer().unwrap();
    }

    #[test]
    fn list_entries_round_trip() {
        let dir = temp_dir();
        let path = dir.path().join("list.snap");

        let mut deque = VecDeque::new();
        deque.push_back(Bytes::from("a"));
        deque.push_back(Bytes::from("b"));
        deque.push_back(Bytes::from("c"));

        let entries = vec![
            SnapEntry {
                key: "mylist".into(),
                value: SnapValue::List(deque),
                expire_ms: -1,
            },
            SnapEntry {
                key: "mystr".into(),
                value: SnapValue::String(Bytes::from("val")),
                expire_ms: 1000,
            },
        ];

        {
            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
            for entry in &entries {
                writer.write_entry(entry).unwrap();
            }
            writer.finish().unwrap();
        }

        let mut reader = SnapshotReader::open(&path).unwrap();
        let mut got = Vec::new();
        while let Some(entry) = reader.read_entry().unwrap() {
            got.push(entry);
        }
        assert_eq!(entries, got);
        reader.verify_footer().unwrap();
    }

    #[test]
    fn sorted_set_entries_round_trip() {
        let dir = temp_dir();
        let path = dir.path().join("zset.snap");

        let entries = vec![
            SnapEntry {
                key: "board".into(),
                value: SnapValue::SortedSet(vec![
                    (100.0, "alice".into()),
                    (200.0, "bob".into()),
                    (150.0, "charlie".into()),
                ]),
                expire_ms: -1,
            },
            SnapEntry {
                key: "mystr".into(),
                value: SnapValue::String(Bytes::from("val")),
                expire_ms: 1000,
            },
        ];

        {
            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
            for entry in &entries {
                writer.write_entry(entry).unwrap();
            }
            writer.finish().unwrap();
        }

        let mut reader = SnapshotReader::open(&path).unwrap();
        let mut got = Vec::new();
        while let Some(entry) = reader.read_entry().unwrap() {
            got.push(entry);
        }
        assert_eq!(entries, got);
        reader.verify_footer().unwrap();
    }

    #[test]
    fn snapshot_path_format() {
        let p = snapshot_path(Path::new("/data"), 5);
        assert_eq!(p, PathBuf::from("/data/shard-5.snap"));
    }
}