Skip to main content

ember_persistence/
snapshot.rs

1//! Point-in-time snapshot files.
2//!
3//! Each shard writes its own snapshot (`shard-{id}.snap`). The format
4//! stores all live entries in a single pass. Writes go to a `.tmp`
5//! file first and are atomically renamed on completion — this ensures
6//! a partial/crashed snapshot never corrupts the existing `.snap` file.
7//!
8//! File layout:
9//! ```text
10//! [ESNP magic: 4B][version: 1B][shard_id: 2B][entry_count: 4B]
11//! [entries...]
12//! [footer_crc32: 4B]
13//! ```
14//!
15//! Each entry (v2, type-tagged):
16//! ```text
17//! [key_len: 4B][key][type_tag: 1B][type-specific payload][expire_ms: 8B]
18//! ```
19//!
//! Type tags: 0=string, 1=list, 2=sorted set, 3=hash, 4=set.
21//! `expire_ms` is the remaining TTL in milliseconds, or -1 for no expiry.
22//! v1 entries (no type tag) are still readable for backward compatibility.
23
24use std::collections::{HashMap, HashSet, VecDeque};
25use std::fs::{self, File, OpenOptions};
26use std::io::{self, BufReader, BufWriter, Write};
27use std::path::{Path, PathBuf};
28
29use bytes::Bytes;
30
31use crate::format::{self, FormatError};
32
// One-byte type tags written immediately after the key in v2+ entries.
// These values are part of the on-disk format and must never be renumbered.

/// Tag for a `SnapValue::String` payload.
const TYPE_STRING: u8 = 0x00;
/// Tag for a `SnapValue::List` payload.
const TYPE_LIST: u8 = 0x01;
/// Tag for a `SnapValue::SortedSet` payload.
const TYPE_SORTED_SET: u8 = 0x02;
/// Tag for a `SnapValue::Hash` payload.
const TYPE_HASH: u8 = 0x03;
/// Tag for a `SnapValue::Set` payload.
const TYPE_SET: u8 = 0x04;
39
/// The value stored in a snapshot entry.
///
/// In v2+ files each variant is serialized after a one-byte type tag
/// (the `TYPE_*` constants in this module). v1 files carry no tag and
/// are always read back as [`SnapValue::String`].
#[derive(Debug, Clone, PartialEq)]
pub enum SnapValue {
    /// A string value, stored as raw bytes (not validated as UTF-8 on read).
    String(Bytes),
    /// A list of values, written and restored in front-to-back order.
    List(VecDeque<Bytes>),
    /// A sorted set: vec of (score, member) pairs.
    ///
    /// Pairs are written and read back in vec order; nothing in this module
    /// sorts them. NOTE(review): presumably callers supply them already
    /// ordered — confirm.
    SortedSet(Vec<(f64, String)>),
    /// A hash: map of field names to values. Fields are written in the
    /// map's iteration order; values are arbitrary bytes.
    Hash(HashMap<String, Bytes>),
    /// An unordered set of unique string members, written in the set's
    /// iteration order.
    Set(HashSet<String>),
}
54
/// A single entry in a snapshot file: a key, its value, and its TTL.
#[derive(Debug, Clone, PartialEq)]
pub struct SnapEntry {
    /// The entry's key; rejected on read if the stored bytes are not UTF-8.
    pub key: String,
    /// The stored value.
    pub value: SnapValue,
    /// Remaining TTL in milliseconds, or -1 for no expiration.
    ///
    /// Written and read verbatim as a signed 64-bit integer; this module
    /// does not interpret or validate it.
    pub expire_ms: i64,
}
63
/// Writes a complete snapshot to disk.
///
/// Entries are written to a temporary file first, then atomically
/// renamed to the final path. The caller feeds entries one at a time
/// through [`SnapshotWriter::write_entry`] and then calls
/// [`SnapshotWriter::finish`]; dropping the writer without finishing
/// discards the temporary file.
pub struct SnapshotWriter {
    /// Destination path; only created by the rename in `finish`.
    final_path: PathBuf,
    /// Staging file: `final_path` with its extension replaced by `snap.tmp`.
    tmp_path: PathBuf,
    /// Buffered handle to `tmp_path`.
    writer: BufWriter<File>,
    /// Running CRC over all entry bytes for the footer checksum.
    /// The header is not included.
    hasher: crc32fast::Hasher,
    /// Number of entries written so far; patched into the header by `finish`.
    count: u32,
    /// Set to `true` after a successful `finish()`. Used by the `Drop`
    /// impl to clean up incomplete temp files.
    finished: bool,
}
80
81impl SnapshotWriter {
82    /// Creates a new snapshot writer. The file won't appear at `path`
83    /// until [`Self::finish`] is called successfully.
84    pub fn create(path: impl Into<PathBuf>, shard_id: u16) -> Result<Self, FormatError> {
85        let final_path = path.into();
86        let tmp_path = final_path.with_extension("snap.tmp");
87
88        let file = {
89            let mut opts = OpenOptions::new();
90            opts.write(true).create(true).truncate(true);
91            #[cfg(unix)]
92            {
93                use std::os::unix::fs::OpenOptionsExt;
94                opts.mode(0o600);
95            }
96            opts.open(&tmp_path)?
97        };
98        let mut writer = BufWriter::new(file);
99
100        // write header: magic + version + shard_id + placeholder entry count
101        format::write_header(&mut writer, format::SNAP_MAGIC)?;
102        format::write_u16(&mut writer, shard_id)?;
103        // entry count — we'll seek back and update, or just write it now
104        // and track. since we're streaming, write 0 and update after.
105        format::write_u32(&mut writer, 0)?;
106
107        Ok(Self {
108            final_path,
109            tmp_path,
110            writer,
111            hasher: crc32fast::Hasher::new(),
112            count: 0,
113            finished: false,
114        })
115    }
116
117    /// Writes a single entry to the snapshot.
118    pub fn write_entry(&mut self, entry: &SnapEntry) -> Result<(), FormatError> {
119        let mut buf = Vec::new();
120        format::write_bytes(&mut buf, entry.key.as_bytes())?;
121        match &entry.value {
122            SnapValue::String(data) => {
123                format::write_u8(&mut buf, TYPE_STRING)?;
124                format::write_bytes(&mut buf, data)?;
125            }
126            SnapValue::List(deque) => {
127                format::write_u8(&mut buf, TYPE_LIST)?;
128                format::write_u32(&mut buf, deque.len() as u32)?;
129                for item in deque {
130                    format::write_bytes(&mut buf, item)?;
131                }
132            }
133            SnapValue::SortedSet(members) => {
134                format::write_u8(&mut buf, TYPE_SORTED_SET)?;
135                format::write_u32(&mut buf, members.len() as u32)?;
136                for (score, member) in members {
137                    format::write_f64(&mut buf, *score)?;
138                    format::write_bytes(&mut buf, member.as_bytes())?;
139                }
140            }
141            SnapValue::Hash(map) => {
142                format::write_u8(&mut buf, TYPE_HASH)?;
143                format::write_u32(&mut buf, map.len() as u32)?;
144                for (field, value) in map {
145                    format::write_bytes(&mut buf, field.as_bytes())?;
146                    format::write_bytes(&mut buf, value)?;
147                }
148            }
149            SnapValue::Set(set) => {
150                format::write_u8(&mut buf, TYPE_SET)?;
151                format::write_u32(&mut buf, set.len() as u32)?;
152                for member in set {
153                    format::write_bytes(&mut buf, member.as_bytes())?;
154                }
155            }
156        }
157        format::write_i64(&mut buf, entry.expire_ms)?;
158
159        self.hasher.update(&buf);
160        self.writer.write_all(&buf)?;
161        self.count += 1;
162        Ok(())
163    }
164
165    /// Finalizes the snapshot: writes the footer CRC, flushes, and
166    /// atomically renames the temp file to the final path.
167    pub fn finish(mut self) -> Result<(), FormatError> {
168        // write footer CRC — clone the hasher so we don't move out of self
169        let checksum = self.hasher.clone().finalize();
170        format::write_u32(&mut self.writer, checksum)?;
171        self.writer.flush()?;
172        self.writer.get_ref().sync_all()?;
173
174        // rewrite the header with the correct entry count.
175        // open a second handle for the seek — the BufWriter is already
176        // flushed and synced above.
177        {
178            use std::io::{Seek, SeekFrom};
179            let mut file = fs::OpenOptions::new().write(true).open(&self.tmp_path)?;
180            // header: 4 (magic) + 1 (version) + 2 (shard_id) = 7 bytes
181            file.seek(SeekFrom::Start(7))?;
182            format::write_u32(&mut file, self.count)?;
183            file.sync_all()?;
184        }
185
186        // atomic rename
187        fs::rename(&self.tmp_path, &self.final_path)?;
188        self.finished = true;
189        Ok(())
190    }
191}
192
193impl Drop for SnapshotWriter {
194    fn drop(&mut self) {
195        if !self.finished {
196            // best-effort cleanup of incomplete temp file
197            let _ = fs::remove_file(&self.tmp_path);
198        }
199    }
200}
201
/// Reads entries from a snapshot file.
///
/// Call [`SnapshotReader::read_entry`] until it returns `None`, then
/// [`SnapshotReader::verify_footer`] to check the CRC. Entries are handed
/// out before the checksum is verified, so they should not be trusted
/// until `verify_footer` succeeds.
pub struct SnapshotReader {
    reader: BufReader<File>,
    /// Shard id recorded in the file header.
    pub shard_id: u16,
    /// Number of entries the header says the file contains.
    pub entry_count: u32,
    /// Entries returned so far by `read_entry`.
    read_so_far: u32,
    /// Running CRC over the re-encoded entry bytes; compared against the footer.
    hasher: crc32fast::Hasher,
    /// Format version — v1 has no type tags, v2 has type-tagged entries.
    version: u8,
}
212
213impl SnapshotReader {
214    /// Opens a snapshot file and reads the header.
215    pub fn open(path: impl AsRef<Path>) -> Result<Self, FormatError> {
216        let file = File::open(path.as_ref())?;
217        let mut reader = BufReader::new(file);
218
219        let version = format::read_header(&mut reader, format::SNAP_MAGIC)?;
220        let shard_id = format::read_u16(&mut reader)?;
221        let entry_count = format::read_u32(&mut reader)?;
222
223        Ok(Self {
224            reader,
225            shard_id,
226            entry_count,
227            read_so_far: 0,
228            hasher: crc32fast::Hasher::new(),
229            version,
230        })
231    }
232
233    /// Reads the next entry. Returns `None` when all entries have been read.
234    pub fn read_entry(&mut self) -> Result<Option<SnapEntry>, FormatError> {
235        if self.read_so_far >= self.entry_count {
236            return Ok(None);
237        }
238
239        let mut buf = Vec::new();
240
241        let key_bytes = format::read_bytes(&mut self.reader)?;
242        format::write_bytes(&mut buf, &key_bytes)?;
243
244        let value = if self.version == 1 {
245            // v1: no type tag, value is always a string
246            let value_bytes = format::read_bytes(&mut self.reader)?;
247            format::write_bytes(&mut buf, &value_bytes)?;
248            SnapValue::String(Bytes::from(value_bytes))
249        } else {
250            // v2+: type-tagged values
251            let type_tag = format::read_u8(&mut self.reader)?;
252            format::write_u8(&mut buf, type_tag)?;
253            match type_tag {
254                TYPE_STRING => {
255                    let value_bytes = format::read_bytes(&mut self.reader)?;
256                    format::write_bytes(&mut buf, &value_bytes)?;
257                    SnapValue::String(Bytes::from(value_bytes))
258                }
259                TYPE_LIST => {
260                    let count = format::read_u32(&mut self.reader)?;
261                    format::write_u32(&mut buf, count)?;
262                    let mut deque = VecDeque::with_capacity(count as usize);
263                    for _ in 0..count {
264                        let item = format::read_bytes(&mut self.reader)?;
265                        format::write_bytes(&mut buf, &item)?;
266                        deque.push_back(Bytes::from(item));
267                    }
268                    SnapValue::List(deque)
269                }
270                TYPE_SORTED_SET => {
271                    let count = format::read_u32(&mut self.reader)?;
272                    format::write_u32(&mut buf, count)?;
273                    let mut members = Vec::with_capacity(count as usize);
274                    for _ in 0..count {
275                        let score = format::read_f64(&mut self.reader)?;
276                        format::write_f64(&mut buf, score)?;
277                        let member_bytes = format::read_bytes(&mut self.reader)?;
278                        format::write_bytes(&mut buf, &member_bytes)?;
279                        let member = String::from_utf8(member_bytes).map_err(|_| {
280                            FormatError::Io(io::Error::new(
281                                io::ErrorKind::InvalidData,
282                                "member is not valid utf-8",
283                            ))
284                        })?;
285                        members.push((score, member));
286                    }
287                    SnapValue::SortedSet(members)
288                }
289                TYPE_HASH => {
290                    let count = format::read_u32(&mut self.reader)?;
291                    format::write_u32(&mut buf, count)?;
292                    let mut map = HashMap::with_capacity(count as usize);
293                    for _ in 0..count {
294                        let field_bytes = format::read_bytes(&mut self.reader)?;
295                        format::write_bytes(&mut buf, &field_bytes)?;
296                        let field = String::from_utf8(field_bytes).map_err(|_| {
297                            FormatError::Io(io::Error::new(
298                                io::ErrorKind::InvalidData,
299                                "hash field is not valid utf-8",
300                            ))
301                        })?;
302                        let value_bytes = format::read_bytes(&mut self.reader)?;
303                        format::write_bytes(&mut buf, &value_bytes)?;
304                        map.insert(field, Bytes::from(value_bytes));
305                    }
306                    SnapValue::Hash(map)
307                }
308                TYPE_SET => {
309                    let count = format::read_u32(&mut self.reader)?;
310                    format::write_u32(&mut buf, count)?;
311                    let mut set = HashSet::with_capacity(count as usize);
312                    for _ in 0..count {
313                        let member_bytes = format::read_bytes(&mut self.reader)?;
314                        format::write_bytes(&mut buf, &member_bytes)?;
315                        let member = String::from_utf8(member_bytes).map_err(|_| {
316                            FormatError::Io(io::Error::new(
317                                io::ErrorKind::InvalidData,
318                                "set member is not valid utf-8",
319                            ))
320                        })?;
321                        set.insert(member);
322                    }
323                    SnapValue::Set(set)
324                }
325                _ => {
326                    return Err(FormatError::UnknownTag(type_tag));
327                }
328            }
329        };
330
331        let expire_ms = format::read_i64(&mut self.reader)?;
332        format::write_i64(&mut buf, expire_ms)?;
333        self.hasher.update(&buf);
334
335        let key = String::from_utf8(key_bytes).map_err(|_| {
336            FormatError::Io(io::Error::new(
337                io::ErrorKind::InvalidData,
338                "key is not valid utf-8",
339            ))
340        })?;
341
342        self.read_so_far += 1;
343        Ok(Some(SnapEntry {
344            key,
345            value,
346            expire_ms,
347        }))
348    }
349
350    /// Verifies the footer CRC32 after all entries have been read.
351    /// Must be called after reading all entries.
352    pub fn verify_footer(self) -> Result<(), FormatError> {
353        let expected = self.hasher.finalize();
354        let mut reader = self.reader;
355        let stored = format::read_u32(&mut reader)?;
356        format::verify_crc32_values(expected, stored)
357    }
358}
359
/// Builds the on-disk location of a shard's snapshot inside `data_dir`,
/// named `shard-<shard_id>.snap`.
pub fn snapshot_path(data_dir: &Path, shard_id: u16) -> PathBuf {
    let mut path = data_dir.to_path_buf();
    path.push(format!("shard-{shard_id}.snap"));
    path
}
364
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh temporary directory, removed when the returned guard drops.
    fn temp_dir() -> tempfile::TempDir {
        tempfile::tempdir().expect("create temp dir")
    }

    #[test]
    fn empty_snapshot_round_trip() {
        let dir = temp_dir();
        let path = dir.path().join("empty.snap");

        {
            let writer = SnapshotWriter::create(&path, 0).unwrap();
            writer.finish().unwrap();
        }

        let reader = SnapshotReader::open(&path).unwrap();
        assert_eq!(reader.shard_id, 0);
        assert_eq!(reader.entry_count, 0);
        reader.verify_footer().unwrap();
    }

    #[test]
    fn entries_round_trip() {
        let dir = temp_dir();
        let path = dir.path().join("data.snap");

        let entries = vec![
            SnapEntry {
                key: "hello".into(),
                value: SnapValue::String(Bytes::from("world")),
                expire_ms: -1,
            },
            SnapEntry {
                key: "ttl".into(),
                value: SnapValue::String(Bytes::from("expiring")),
                expire_ms: 5000,
            },
            SnapEntry {
                key: "empty".into(),
                value: SnapValue::String(Bytes::new()),
                expire_ms: -1,
            },
        ];

        {
            let mut writer = SnapshotWriter::create(&path, 7).unwrap();
            for entry in &entries {
                writer.write_entry(entry).unwrap();
            }
            writer.finish().unwrap();
        }

        let mut reader = SnapshotReader::open(&path).unwrap();
        assert_eq!(reader.shard_id, 7);
        assert_eq!(reader.entry_count, 3);

        let mut got = Vec::new();
        while let Some(entry) = reader.read_entry().unwrap() {
            got.push(entry);
        }
        assert_eq!(entries, got);
        reader.verify_footer().unwrap();
    }

    #[test]
    fn corrupt_footer_detected() {
        let dir = temp_dir();
        let path = dir.path().join("corrupt.snap");

        {
            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
            writer
                .write_entry(&SnapEntry {
                    key: "k".into(),
                    value: SnapValue::String(Bytes::from("v")),
                    expire_ms: -1,
                })
                .unwrap();
            writer.finish().unwrap();
        }

        // corrupt the last byte (footer CRC)
        let mut data = fs::read(&path).unwrap();
        let last = data.len() - 1;
        data[last] ^= 0xFF;
        fs::write(&path, &data).unwrap();

        let mut reader = SnapshotReader::open(&path).unwrap();
        // reading entries should still work
        reader.read_entry().unwrap();
        // but footer verification should fail
        let err = reader.verify_footer().unwrap_err();
        assert!(matches!(err, FormatError::ChecksumMismatch { .. }));
    }

    #[test]
    fn atomic_rename_prevents_partial_snapshots() {
        let dir = temp_dir();
        let path = dir.path().join("atomic.snap");

        // write an initial snapshot
        {
            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
            writer
                .write_entry(&SnapEntry {
                    key: "original".into(),
                    value: SnapValue::String(Bytes::from("data")),
                    expire_ms: -1,
                })
                .unwrap();
            writer.finish().unwrap();
        }

        // start a second snapshot but don't finish it
        {
            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
            writer
                .write_entry(&SnapEntry {
                    key: "new".into(),
                    value: SnapValue::String(Bytes::from("partial")),
                    expire_ms: -1,
                })
                .unwrap();
            // drop without finish — simulates a crash
            drop(writer);
        }

        // the original snapshot should still be intact
        let mut reader = SnapshotReader::open(&path).unwrap();
        let entry = reader.read_entry().unwrap().unwrap();
        assert_eq!(entry.key, "original");

        // the tmp file should have been cleaned up by Drop
        let tmp = path.with_extension("snap.tmp");
        assert!(!tmp.exists(), "drop should clean up incomplete tmp file");
    }

    #[test]
    fn finish_leaves_no_tmp_file() {
        let dir = temp_dir();
        let path = dir.path().join("clean.snap");

        SnapshotWriter::create(&path, 1).unwrap().finish().unwrap();

        assert!(path.exists(), "finish should publish the final snapshot");
        assert!(
            !path.with_extension("snap.tmp").exists(),
            "finish should rename the tmp file away"
        );
    }

    #[test]
    fn ttl_entries_preserved() {
        let dir = temp_dir();
        let path = dir.path().join("ttl.snap");

        let entry = SnapEntry {
            key: "expires".into(),
            value: SnapValue::String(Bytes::from("soon")),
            expire_ms: 42_000,
        };

        {
            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
            writer.write_entry(&entry).unwrap();
            writer.finish().unwrap();
        }

        let mut reader = SnapshotReader::open(&path).unwrap();
        let got = reader.read_entry().unwrap().unwrap();
        assert_eq!(got.expire_ms, 42_000);
        reader.verify_footer().unwrap();
    }

    #[test]
    fn list_entries_round_trip() {
        let dir = temp_dir();
        let path = dir.path().join("list.snap");

        let mut deque = VecDeque::new();
        deque.push_back(Bytes::from("a"));
        deque.push_back(Bytes::from("b"));
        deque.push_back(Bytes::from("c"));

        let entries = vec![
            SnapEntry {
                key: "mylist".into(),
                value: SnapValue::List(deque),
                expire_ms: -1,
            },
            SnapEntry {
                key: "mystr".into(),
                value: SnapValue::String(Bytes::from("val")),
                expire_ms: 1000,
            },
        ];

        {
            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
            for entry in &entries {
                writer.write_entry(entry).unwrap();
            }
            writer.finish().unwrap();
        }

        let mut reader = SnapshotReader::open(&path).unwrap();
        let mut got = Vec::new();
        while let Some(entry) = reader.read_entry().unwrap() {
            got.push(entry);
        }
        assert_eq!(entries, got);
        reader.verify_footer().unwrap();
    }

    #[test]
    fn sorted_set_entries_round_trip() {
        let dir = temp_dir();
        let path = dir.path().join("zset.snap");

        let entries = vec![
            SnapEntry {
                key: "board".into(),
                value: SnapValue::SortedSet(vec![
                    (100.0, "alice".into()),
                    (200.0, "bob".into()),
                    (150.0, "charlie".into()),
                ]),
                expire_ms: -1,
            },
            SnapEntry {
                key: "mystr".into(),
                value: SnapValue::String(Bytes::from("val")),
                expire_ms: 1000,
            },
        ];

        {
            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
            for entry in &entries {
                writer.write_entry(entry).unwrap();
            }
            writer.finish().unwrap();
        }

        let mut reader = SnapshotReader::open(&path).unwrap();
        let mut got = Vec::new();
        while let Some(entry) = reader.read_entry().unwrap() {
            got.push(entry);
        }
        assert_eq!(entries, got);
        reader.verify_footer().unwrap();
    }

    /// Hash and set values (type tags 3 and 4) previously had no round-trip
    /// coverage; this also covers empty collections of both kinds.
    #[test]
    fn hash_and_set_entries_round_trip() {
        let dir = temp_dir();
        let path = dir.path().join("hash_set.snap");

        let mut fields = HashMap::new();
        fields.insert("name".to_string(), Bytes::from("ember"));
        fields.insert("empty".to_string(), Bytes::new());

        let members: HashSet<String> = ["x", "y", "z"].iter().map(|m| m.to_string()).collect();

        let entries = vec![
            SnapEntry {
                key: "myhash".into(),
                value: SnapValue::Hash(fields),
                expire_ms: -1,
            },
            SnapEntry {
                key: "myset".into(),
                value: SnapValue::Set(members),
                expire_ms: 2_500,
            },
            SnapEntry {
                key: "emptyhash".into(),
                value: SnapValue::Hash(HashMap::new()),
                expire_ms: -1,
            },
            SnapEntry {
                key: "emptyset".into(),
                value: SnapValue::Set(HashSet::new()),
                expire_ms: -1,
            },
        ];

        {
            let mut writer = SnapshotWriter::create(&path, 3).unwrap();
            for entry in &entries {
                writer.write_entry(entry).unwrap();
            }
            writer.finish().unwrap();
        }

        let mut reader = SnapshotReader::open(&path).unwrap();
        assert_eq!(reader.shard_id, 3);
        assert_eq!(reader.entry_count, 4);

        let mut got = Vec::new();
        while let Some(entry) = reader.read_entry().unwrap() {
            got.push(entry);
        }
        assert_eq!(entries, got);
        reader.verify_footer().unwrap();
    }

    #[test]
    fn snapshot_path_format() {
        let p = snapshot_path(Path::new("/data"), 5);
        assert_eq!(p, PathBuf::from("/data/shard-5.snap"));
    }
}