Skip to main content

ember_persistence/
snapshot.rs

//! Point-in-time snapshot files.
//!
//! Each shard writes its own snapshot (`shard-{id}.snap`). The format
//! stores all live entries in a single pass. Writes go to a `.snap.tmp`
//! file first and are atomically renamed on completion — this ensures
//! a partial/crashed snapshot never corrupts the existing `.snap` file.
//!
//! File layout:
//! ```text
//! [ESNP magic: 4B][version: 1B][shard_id: 2B][entry_count: 4B]
//! [entries...]
//! [footer_crc32: 4B]
//! ```
//!
//! The footer CRC32 covers the entry bytes only, not the header.
//!
//! Each entry (v2, type-tagged):
//! ```text
//! [key_len: 4B][key][type_tag: 1B][type-specific payload][expire_ms: 8B]
//! ```
//!
//! Type tags: 0=string, 1=list, 2=sorted set, 3=hash, 4=set.
//! `expire_ms` is the remaining TTL in milliseconds, or -1 for no expiry.
//! v1 entries (no type tag) are still readable for backward compatibility.
23
24use std::collections::{HashMap, HashSet, VecDeque};
25use std::fs::{self, File};
26use std::io::{self, BufReader, BufWriter, Write};
27use std::path::{Path, PathBuf};
28
29use bytes::Bytes;
30
31use crate::format::{self, FormatError};
32
// Type tags for snapshot entries: a single byte written after the key of
// every v2 entry, selecting the payload layout that follows.

/// Tag for a string value: one length-prefixed byte string.
const TYPE_STRING: u8 = 0;
/// Tag for a list value: element count, then each element in order.
const TYPE_LIST: u8 = 1;
/// Tag for a sorted set: member count, then (score, member) pairs.
const TYPE_SORTED_SET: u8 = 2;
/// Tag for a hash: field count, then (field, value) pairs.
const TYPE_HASH: u8 = 3;
/// Tag for an unordered set: member count, then each member.
const TYPE_SET: u8 = 4;
39
/// The value stored in a snapshot entry.
///
/// Each variant corresponds to one on-disk type tag; the writer picks the
/// tag from the variant and the reader rebuilds the variant from the tag.
#[derive(Debug, Clone, PartialEq)]
pub enum SnapValue {
    /// A string value, stored as raw bytes.
    String(Bytes),
    /// A list of values, written front to back.
    List(VecDeque<Bytes>),
    /// A sorted set: vec of (score, member) pairs, written in vec order.
    SortedSet(Vec<(f64, String)>),
    /// A hash: map of field names to values.
    Hash(HashMap<String, Bytes>),
    /// An unordered set of unique string members.
    Set(HashSet<String>),
}
54
/// A single entry in a snapshot file: one key, its value, and its TTL.
#[derive(Debug, Clone, PartialEq)]
pub struct SnapEntry {
    /// The entry's key. Keys are read back as UTF-8 strings.
    pub key: String,
    /// The typed value stored under `key`.
    pub value: SnapValue,
    /// Remaining TTL in milliseconds, or -1 for no expiration.
    pub expire_ms: i64,
}
63
/// Writes a complete snapshot to disk.
///
/// Entries are written to a temporary file first, then atomically
/// renamed to the final path. The caller feeds entries one at a time
/// through `write_entry` and completes the snapshot with `finish`;
/// dropping the writer without finishing removes the temporary file.
pub struct SnapshotWriter {
    /// Destination the finished snapshot is renamed to.
    final_path: PathBuf,
    /// Staging file that receives all writes until `finish()`.
    tmp_path: PathBuf,
    /// Buffered handle on the staging file.
    writer: BufWriter<File>,
    /// Running CRC over all entry bytes for the footer checksum.
    hasher: crc32fast::Hasher,
    /// Number of entries written so far; patched into the header by `finish()`.
    count: u32,
    /// Set to `true` after a successful `finish()`. Used by the `Drop`
    /// impl to clean up incomplete temp files.
    finished: bool,
}
80
81impl SnapshotWriter {
82    /// Creates a new snapshot writer. The file won't appear at `path`
83    /// until [`finish`] is called successfully.
84    pub fn create(path: impl Into<PathBuf>, shard_id: u16) -> Result<Self, FormatError> {
85        let final_path = path.into();
86        let tmp_path = final_path.with_extension("snap.tmp");
87
88        let file = File::create(&tmp_path)?;
89        let mut writer = BufWriter::new(file);
90
91        // write header: magic + version + shard_id + placeholder entry count
92        format::write_header(&mut writer, format::SNAP_MAGIC)?;
93        format::write_u16(&mut writer, shard_id)?;
94        // entry count — we'll seek back and update, or just write it now
95        // and track. since we're streaming, write 0 and update after.
96        format::write_u32(&mut writer, 0)?;
97
98        Ok(Self {
99            final_path,
100            tmp_path,
101            writer,
102            hasher: crc32fast::Hasher::new(),
103            count: 0,
104            finished: false,
105        })
106    }
107
108    /// Writes a single entry to the snapshot.
109    pub fn write_entry(&mut self, entry: &SnapEntry) -> Result<(), FormatError> {
110        let mut buf = Vec::new();
111        format::write_bytes(&mut buf, entry.key.as_bytes())?;
112        match &entry.value {
113            SnapValue::String(data) => {
114                format::write_u8(&mut buf, TYPE_STRING)?;
115                format::write_bytes(&mut buf, data)?;
116            }
117            SnapValue::List(deque) => {
118                format::write_u8(&mut buf, TYPE_LIST)?;
119                format::write_u32(&mut buf, deque.len() as u32)?;
120                for item in deque {
121                    format::write_bytes(&mut buf, item)?;
122                }
123            }
124            SnapValue::SortedSet(members) => {
125                format::write_u8(&mut buf, TYPE_SORTED_SET)?;
126                format::write_u32(&mut buf, members.len() as u32)?;
127                for (score, member) in members {
128                    format::write_f64(&mut buf, *score)?;
129                    format::write_bytes(&mut buf, member.as_bytes())?;
130                }
131            }
132            SnapValue::Hash(map) => {
133                format::write_u8(&mut buf, TYPE_HASH)?;
134                format::write_u32(&mut buf, map.len() as u32)?;
135                for (field, value) in map {
136                    format::write_bytes(&mut buf, field.as_bytes())?;
137                    format::write_bytes(&mut buf, value)?;
138                }
139            }
140            SnapValue::Set(set) => {
141                format::write_u8(&mut buf, TYPE_SET)?;
142                format::write_u32(&mut buf, set.len() as u32)?;
143                for member in set {
144                    format::write_bytes(&mut buf, member.as_bytes())?;
145                }
146            }
147        }
148        format::write_i64(&mut buf, entry.expire_ms)?;
149
150        self.hasher.update(&buf);
151        self.writer.write_all(&buf)?;
152        self.count += 1;
153        Ok(())
154    }
155
156    /// Finalizes the snapshot: writes the footer CRC, flushes, and
157    /// atomically renames the temp file to the final path.
158    pub fn finish(mut self) -> Result<(), FormatError> {
159        // write footer CRC — clone the hasher so we don't move out of self
160        let checksum = self.hasher.clone().finalize();
161        format::write_u32(&mut self.writer, checksum)?;
162        self.writer.flush()?;
163        self.writer.get_ref().sync_all()?;
164
165        // rewrite the header with the correct entry count.
166        // open a second handle for the seek — the BufWriter is already
167        // flushed and synced above.
168        {
169            use std::io::{Seek, SeekFrom};
170            let mut file = fs::OpenOptions::new().write(true).open(&self.tmp_path)?;
171            // header: 4 (magic) + 1 (version) + 2 (shard_id) = 7 bytes
172            file.seek(SeekFrom::Start(7))?;
173            format::write_u32(&mut file, self.count)?;
174            file.sync_all()?;
175        }
176
177        // atomic rename
178        fs::rename(&self.tmp_path, &self.final_path)?;
179        self.finished = true;
180        Ok(())
181    }
182}
183
184impl Drop for SnapshotWriter {
185    fn drop(&mut self) {
186        if !self.finished {
187            // best-effort cleanup of incomplete temp file
188            let _ = fs::remove_file(&self.tmp_path);
189        }
190    }
191}
192
/// Reads entries from a snapshot file.
///
/// The header is parsed on open; entries are then pulled one at a time
/// while a running CRC is accumulated for the final footer check.
pub struct SnapshotReader {
    /// Buffered handle positioned just past the last byte consumed.
    reader: BufReader<File>,
    /// Shard id recorded in the file header.
    pub shard_id: u16,
    /// Number of entries the header says the file contains.
    pub entry_count: u32,
    /// Number of entries returned so far.
    read_so_far: u32,
    /// Running CRC over the entries read so far, compared against the footer.
    hasher: crc32fast::Hasher,
    /// Format version — v1 has no type tags, v2 has type-tagged entries.
    version: u8,
}
203
204impl SnapshotReader {
205    /// Opens a snapshot file and reads the header.
206    pub fn open(path: impl AsRef<Path>) -> Result<Self, FormatError> {
207        let file = File::open(path.as_ref())?;
208        let mut reader = BufReader::new(file);
209
210        let version = format::read_header(&mut reader, format::SNAP_MAGIC)?;
211        let shard_id = format::read_u16(&mut reader)?;
212        let entry_count = format::read_u32(&mut reader)?;
213
214        Ok(Self {
215            reader,
216            shard_id,
217            entry_count,
218            read_so_far: 0,
219            hasher: crc32fast::Hasher::new(),
220            version,
221        })
222    }
223
224    /// Reads the next entry. Returns `None` when all entries have been read.
225    pub fn read_entry(&mut self) -> Result<Option<SnapEntry>, FormatError> {
226        if self.read_so_far >= self.entry_count {
227            return Ok(None);
228        }
229
230        let mut buf = Vec::new();
231
232        let key_bytes = format::read_bytes(&mut self.reader)?;
233        format::write_bytes(&mut buf, &key_bytes).expect("vec write");
234
235        let value = if self.version == 1 {
236            // v1: no type tag, value is always a string
237            let value_bytes = format::read_bytes(&mut self.reader)?;
238            format::write_bytes(&mut buf, &value_bytes).expect("vec write");
239            SnapValue::String(Bytes::from(value_bytes))
240        } else {
241            // v2+: type-tagged values
242            let type_tag = format::read_u8(&mut self.reader)?;
243            format::write_u8(&mut buf, type_tag).expect("vec write");
244            match type_tag {
245                TYPE_STRING => {
246                    let value_bytes = format::read_bytes(&mut self.reader)?;
247                    format::write_bytes(&mut buf, &value_bytes).expect("vec write");
248                    SnapValue::String(Bytes::from(value_bytes))
249                }
250                TYPE_LIST => {
251                    let count = format::read_u32(&mut self.reader)?;
252                    format::write_u32(&mut buf, count).expect("vec write");
253                    let mut deque = VecDeque::with_capacity(count as usize);
254                    for _ in 0..count {
255                        let item = format::read_bytes(&mut self.reader)?;
256                        format::write_bytes(&mut buf, &item).expect("vec write");
257                        deque.push_back(Bytes::from(item));
258                    }
259                    SnapValue::List(deque)
260                }
261                TYPE_SORTED_SET => {
262                    let count = format::read_u32(&mut self.reader)?;
263                    format::write_u32(&mut buf, count).expect("vec write");
264                    let mut members = Vec::with_capacity(count as usize);
265                    for _ in 0..count {
266                        let score = format::read_f64(&mut self.reader)?;
267                        format::write_f64(&mut buf, score).expect("vec write");
268                        let member_bytes = format::read_bytes(&mut self.reader)?;
269                        format::write_bytes(&mut buf, &member_bytes).expect("vec write");
270                        let member = String::from_utf8(member_bytes).map_err(|_| {
271                            FormatError::Io(io::Error::new(
272                                io::ErrorKind::InvalidData,
273                                "member is not valid utf-8",
274                            ))
275                        })?;
276                        members.push((score, member));
277                    }
278                    SnapValue::SortedSet(members)
279                }
280                TYPE_HASH => {
281                    let count = format::read_u32(&mut self.reader)?;
282                    format::write_u32(&mut buf, count).expect("vec write");
283                    let mut map = HashMap::with_capacity(count as usize);
284                    for _ in 0..count {
285                        let field_bytes = format::read_bytes(&mut self.reader)?;
286                        format::write_bytes(&mut buf, &field_bytes).expect("vec write");
287                        let field = String::from_utf8(field_bytes).map_err(|_| {
288                            FormatError::Io(io::Error::new(
289                                io::ErrorKind::InvalidData,
290                                "hash field is not valid utf-8",
291                            ))
292                        })?;
293                        let value_bytes = format::read_bytes(&mut self.reader)?;
294                        format::write_bytes(&mut buf, &value_bytes).expect("vec write");
295                        map.insert(field, Bytes::from(value_bytes));
296                    }
297                    SnapValue::Hash(map)
298                }
299                TYPE_SET => {
300                    let count = format::read_u32(&mut self.reader)?;
301                    format::write_u32(&mut buf, count).expect("vec write");
302                    let mut set = HashSet::with_capacity(count as usize);
303                    for _ in 0..count {
304                        let member_bytes = format::read_bytes(&mut self.reader)?;
305                        format::write_bytes(&mut buf, &member_bytes).expect("vec write");
306                        let member = String::from_utf8(member_bytes).map_err(|_| {
307                            FormatError::Io(io::Error::new(
308                                io::ErrorKind::InvalidData,
309                                "set member is not valid utf-8",
310                            ))
311                        })?;
312                        set.insert(member);
313                    }
314                    SnapValue::Set(set)
315                }
316                _ => {
317                    return Err(FormatError::UnknownTag(type_tag));
318                }
319            }
320        };
321
322        let expire_ms = format::read_i64(&mut self.reader)?;
323        format::write_i64(&mut buf, expire_ms).expect("vec write");
324        self.hasher.update(&buf);
325
326        let key = String::from_utf8(key_bytes).map_err(|_| {
327            FormatError::Io(io::Error::new(
328                io::ErrorKind::InvalidData,
329                "key is not valid utf-8",
330            ))
331        })?;
332
333        self.read_so_far += 1;
334        Ok(Some(SnapEntry {
335            key,
336            value,
337            expire_ms,
338        }))
339    }
340
341    /// Verifies the footer CRC32 after all entries have been read.
342    /// Must be called after reading all entries.
343    pub fn verify_footer(self) -> Result<(), FormatError> {
344        let expected = self.hasher.finalize();
345        let mut reader = self.reader;
346        let stored = format::read_u32(&mut reader)?;
347        format::verify_crc32_values(expected, stored)
348    }
349}
350
/// Builds the snapshot file path for `shard_id` inside `data_dir`.
///
/// The file is named `shard-{shard_id}.snap`.
pub fn snapshot_path(data_dir: &Path, shard_id: u16) -> PathBuf {
    let file_name = format!("shard-{shard_id}.snap");
    let mut path = data_dir.to_path_buf();
    path.push(file_name);
    path
}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a scratch directory that is removed when the guard drops.
    fn temp_dir() -> tempfile::TempDir {
        tempfile::tempdir().expect("create temp dir")
    }

    /// Builds a string-valued entry.
    fn string_entry(key: &str, value: &str, expire_ms: i64) -> SnapEntry {
        SnapEntry {
            key: key.to_owned(),
            value: SnapValue::String(Bytes::from(value.to_owned())),
            expire_ms,
        }
    }

    /// Writes `entries` to `path` as a finished snapshot for `shard_id`.
    fn write_snapshot(path: &Path, shard_id: u16, entries: &[SnapEntry]) {
        let mut writer = SnapshotWriter::create(path, shard_id).unwrap();
        for entry in entries {
            writer.write_entry(entry).unwrap();
        }
        writer.finish().unwrap();
    }

    /// Drains every remaining entry from `reader`.
    fn read_all(reader: &mut SnapshotReader) -> Vec<SnapEntry> {
        let mut got = Vec::new();
        while let Some(entry) = reader.read_entry().unwrap() {
            got.push(entry);
        }
        got
    }

    /// Writes `entries`, reads them back, and checks both the contents and
    /// the footer checksum.
    fn assert_round_trip(file_name: &str, entries: Vec<SnapEntry>) {
        let dir = temp_dir();
        let path = dir.path().join(file_name);

        write_snapshot(&path, 0, &entries);

        let mut reader = SnapshotReader::open(&path).unwrap();
        let got = read_all(&mut reader);
        assert_eq!(entries, got);
        reader.verify_footer().unwrap();
    }

    #[test]
    fn empty_snapshot_round_trip() {
        let dir = temp_dir();
        let path = dir.path().join("empty.snap");

        write_snapshot(&path, 0, &[]);

        let reader = SnapshotReader::open(&path).unwrap();
        assert_eq!(reader.shard_id, 0);
        assert_eq!(reader.entry_count, 0);
        reader.verify_footer().unwrap();
    }

    #[test]
    fn entries_round_trip() {
        let dir = temp_dir();
        let path = dir.path().join("data.snap");

        let entries = vec![
            string_entry("hello", "world", -1),
            string_entry("ttl", "expiring", 5000),
            string_entry("empty", "", -1),
        ];

        write_snapshot(&path, 7, &entries);

        let mut reader = SnapshotReader::open(&path).unwrap();
        assert_eq!(reader.shard_id, 7);
        assert_eq!(reader.entry_count, 3);

        let got = read_all(&mut reader);
        assert_eq!(entries, got);
        reader.verify_footer().unwrap();
    }

    #[test]
    fn corrupt_footer_detected() {
        let dir = temp_dir();
        let path = dir.path().join("corrupt.snap");

        write_snapshot(&path, 0, &[string_entry("k", "v", -1)]);

        // corrupt the last byte (footer CRC)
        let mut data = fs::read(&path).unwrap();
        let last = data.len() - 1;
        data[last] ^= 0xFF;
        fs::write(&path, &data).unwrap();

        let mut reader = SnapshotReader::open(&path).unwrap();
        // reading entries should still work
        reader.read_entry().unwrap();
        // but footer verification should fail
        let err = reader.verify_footer().unwrap_err();
        assert!(matches!(err, FormatError::ChecksumMismatch { .. }));
    }

    #[test]
    fn atomic_rename_prevents_partial_snapshots() {
        let dir = temp_dir();
        let path = dir.path().join("atomic.snap");

        // write an initial snapshot
        write_snapshot(&path, 0, &[string_entry("original", "data", -1)]);

        // start a second snapshot but don't finish it
        {
            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
            writer
                .write_entry(&string_entry("new", "partial", -1))
                .unwrap();
            // drop without finish — simulates a crash
            drop(writer);
        }

        // the original snapshot should still be intact
        let mut reader = SnapshotReader::open(&path).unwrap();
        let entry = reader.read_entry().unwrap().unwrap();
        assert_eq!(entry.key, "original");

        // the tmp file should have been cleaned up by Drop
        let tmp = path.with_extension("snap.tmp");
        assert!(!tmp.exists(), "drop should clean up incomplete tmp file");
    }

    #[test]
    fn ttl_entries_preserved() {
        let dir = temp_dir();
        let path = dir.path().join("ttl.snap");

        write_snapshot(&path, 0, &[string_entry("expires", "soon", 42_000)]);

        let mut reader = SnapshotReader::open(&path).unwrap();
        let got = reader.read_entry().unwrap().unwrap();
        assert_eq!(got.expire_ms, 42_000);
        reader.verify_footer().unwrap();
    }

    #[test]
    fn list_entries_round_trip() {
        let mut deque = VecDeque::new();
        deque.push_back(Bytes::from("a"));
        deque.push_back(Bytes::from("b"));
        deque.push_back(Bytes::from("c"));

        assert_round_trip(
            "list.snap",
            vec![
                SnapEntry {
                    key: "mylist".into(),
                    value: SnapValue::List(deque),
                    expire_ms: -1,
                },
                string_entry("mystr", "val", 1000),
            ],
        );
    }

    #[test]
    fn sorted_set_entries_round_trip() {
        assert_round_trip(
            "zset.snap",
            vec![
                SnapEntry {
                    key: "board".into(),
                    value: SnapValue::SortedSet(vec![
                        (100.0, "alice".into()),
                        (200.0, "bob".into()),
                        (150.0, "charlie".into()),
                    ]),
                    expire_ms: -1,
                },
                string_entry("mystr", "val", 1000),
            ],
        );
    }

    #[test]
    fn hash_entries_round_trip() {
        let mut map = HashMap::new();
        map.insert("name".to_owned(), Bytes::from("ember"));
        map.insert("blank".to_owned(), Bytes::new());

        assert_round_trip(
            "hash.snap",
            vec![
                SnapEntry {
                    key: "myhash".into(),
                    value: SnapValue::Hash(map),
                    expire_ms: -1,
                },
                // a trailing entry proves the reader stays aligned after the hash
                string_entry("mystr", "val", 1000),
            ],
        );
    }

    #[test]
    fn set_entries_round_trip() {
        let set: HashSet<String> = ["red", "green", "blue"]
            .iter()
            .map(|member| member.to_string())
            .collect();

        assert_round_trip(
            "set.snap",
            vec![
                SnapEntry {
                    key: "myset".into(),
                    value: SnapValue::Set(set),
                    expire_ms: 2500,
                },
                SnapEntry {
                    key: "emptyset".into(),
                    value: SnapValue::Set(HashSet::new()),
                    expire_ms: -1,
                },
                // a trailing entry proves the reader stays aligned after the sets
                string_entry("mystr", "val", 1000),
            ],
        );
    }

    #[test]
    fn snapshot_path_format() {
        let p = snapshot_path(Path::new("/data"), 5);
        assert_eq!(p, PathBuf::from("/data/shard-5.snap"));
    }
}