Skip to main content

ember_persistence/
snapshot.rs

//! Point-in-time snapshot files.
//!
//! Each shard writes its own snapshot (`shard-{id}.snap`). The format
//! stores all live entries in a single pass. Writes go to a `.tmp`
//! file first and are atomically renamed on completion — this ensures
//! a partial/crashed snapshot never corrupts the existing `.snap` file.
//!
//! File layout:
//! ```text
//! [ESNP magic: 4B][version: 1B][shard_id: 2B][entry_count: 4B]
//! [entries...]
//! [footer_crc32: 4B]
//! ```
//!
//! Each entry (v2, type-tagged):
//! ```text
//! [key_len: 4B][key][type_tag: 1B][type-specific payload][expire_ms: 8B]
//! ```
//!
//! Type tags: 0=string, 1=list, 2=sorted set, 3=hash, 4=set,
//! 5=protobuf message (`protobuf` feature), 6=vector set (`vector` feature).
//! `expire_ms` is the remaining TTL in milliseconds, or -1 for no expiry.
//! v1 entries (no type tag) are still readable for backward compatibility.
//! v3 files are encrypted: each entry is stored as
//! `[nonce: 12B][len: 4B][ciphertext]`.
23
24use std::collections::{HashMap, HashSet, VecDeque};
25use std::fs::{self, File, OpenOptions};
26use std::io::{self, BufReader, BufWriter, Write};
27use std::path::{Path, PathBuf};
28
29use bytes::Bytes;
30
31use crate::format::{self, FormatError};
32
// Type tags for snapshot entries: the single byte that follows the key in
// every v2+ entry and selects how the payload after it is decoded.

/// Tag for a string value.
const TYPE_STRING: u8 = 0;
/// Tag for a list of values.
const TYPE_LIST: u8 = 1;
/// Tag for a sorted set of (score, member) pairs.
const TYPE_SORTED_SET: u8 = 2;
/// Tag for a hash of field names to values.
const TYPE_HASH: u8 = 3;
/// Tag for an unordered set of string members.
const TYPE_SET: u8 = 4;
/// Tag for a protobuf message (only compiled with the `protobuf` feature).
#[cfg(feature = "protobuf")]
const TYPE_PROTO: u8 = 5;
/// Tag for a vector set (only compiled with the `vector` feature).
#[cfg(feature = "vector")]
const TYPE_VECTOR: u8 = 6;
43
44/// Converts raw bytes to a UTF-8 string, returning a descriptive error
45/// on invalid data. `field` names the field for the error message
46/// (e.g. "key", "member", "hash field").
47fn parse_utf8(bytes: Vec<u8>, field: &str) -> Result<String, FormatError> {
48    String::from_utf8(bytes).map_err(|_| {
49        FormatError::Io(io::Error::new(
50            io::ErrorKind::InvalidData,
51            format!("{field} is not valid utf-8"),
52        ))
53    })
54}
55
/// Reads one length-prefixed byte field from `r` and decodes it as UTF-8.
///
/// `field` labels the error produced when the bytes are not valid UTF-8.
#[cfg(feature = "encryption")]
fn read_snap_string(r: &mut impl io::Read, field: &str) -> Result<String, FormatError> {
    parse_utf8(format::read_bytes(r)?, field)
}
62
/// Decodes one `[type_tag][payload]` pair from `r` (v2+ layout).
///
/// `read_encrypted_entry` calls this on the decrypted bytes of an entry.
/// The plaintext reader carries its own copy of this logic because it
/// also has to mirror every field into a CRC buffer while reading.
///
/// Returns `FormatError::UnknownTag` for a tag that is not compiled in.
#[cfg(feature = "encryption")]
fn parse_snap_value(r: &mut impl io::Read) -> Result<SnapValue, FormatError> {
    let tag = format::read_u8(r)?;
    let value = match tag {
        TYPE_STRING => SnapValue::String(Bytes::from(format::read_bytes(r)?)),
        TYPE_LIST => {
            let len = format::read_u32(r)?;
            format::validate_collection_count(len, "list")?;
            let mut items = VecDeque::with_capacity(format::capped_capacity(len));
            for _ in 0..len {
                let item = format::read_bytes(r)?;
                items.push_back(Bytes::from(item));
            }
            SnapValue::List(items)
        }
        TYPE_SORTED_SET => {
            let len = format::read_u32(r)?;
            format::validate_collection_count(len, "sorted set")?;
            let mut scored = Vec::with_capacity(format::capped_capacity(len));
            for _ in 0..len {
                // on-disk order is score first, then the member string
                let score = format::read_f64(r)?;
                let member = read_snap_string(r, "member")?;
                scored.push((score, member));
            }
            SnapValue::SortedSet(scored)
        }
        TYPE_HASH => {
            let len = format::read_u32(r)?;
            format::validate_collection_count(len, "hash")?;
            let mut fields = HashMap::with_capacity(format::capped_capacity(len));
            for _ in 0..len {
                let name = read_snap_string(r, "hash field")?;
                let data = format::read_bytes(r)?;
                fields.insert(name, Bytes::from(data));
            }
            SnapValue::Hash(fields)
        }
        TYPE_SET => {
            let len = format::read_u32(r)?;
            format::validate_collection_count(len, "set")?;
            let mut members = HashSet::with_capacity(format::capped_capacity(len));
            for _ in 0..len {
                members.insert(read_snap_string(r, "set member")?);
            }
            SnapValue::Set(members)
        }
        #[cfg(feature = "vector")]
        TYPE_VECTOR => {
            let metric = format::read_u8(r)?;
            if metric > 2 {
                return Err(FormatError::InvalidData(format!(
                    "unknown vector metric: {metric}"
                )));
            }
            let quantization = format::read_u8(r)?;
            if quantization > 2 {
                return Err(FormatError::InvalidData(format!(
                    "unknown vector quantization: {quantization}"
                )));
            }
            let connectivity = format::read_u32(r)?;
            let expansion_add = format::read_u32(r)?;
            let dim = format::read_u32(r)?;
            if dim > format::MAX_PERSISTED_VECTOR_DIMS {
                return Err(FormatError::InvalidData(format!(
                    "vector dimension {dim} exceeds max {}",
                    format::MAX_PERSISTED_VECTOR_DIMS
                )));
            }
            let len = format::read_u32(r)?;
            if len > format::MAX_PERSISTED_VECTOR_COUNT {
                return Err(FormatError::InvalidData(format!(
                    "vector element count {len} exceeds max {}",
                    format::MAX_PERSISTED_VECTOR_COUNT
                )));
            }
            format::validate_vector_total(dim, len)?;
            let mut elements = Vec::with_capacity(format::capped_capacity(len));
            for _ in 0..len {
                let name = read_snap_string(r, "vector element name")?;
                // each element carries exactly `dim` f32 components
                let mut components = Vec::with_capacity(dim as usize);
                for _ in 0..dim {
                    components.push(format::read_f32(r)?);
                }
                elements.push((name, components));
            }
            SnapValue::Vector {
                metric,
                quantization,
                connectivity,
                expansion_add,
                dim,
                elements,
            }
        }
        #[cfg(feature = "protobuf")]
        TYPE_PROTO => {
            let type_name = read_snap_string(r, "proto type_name")?;
            let payload = format::read_bytes(r)?;
            SnapValue::Proto {
                type_name,
                data: Bytes::from(payload),
            }
        }
        unknown => return Err(FormatError::UnknownTag(unknown)),
    };
    Ok(value)
}
178
179/// The value stored in a snapshot entry.
180#[derive(Debug, Clone, PartialEq)]
181pub enum SnapValue {
182    /// A string value.
183    String(Bytes),
184    /// A list of values.
185    List(VecDeque<Bytes>),
186    /// A sorted set: vec of (score, member) pairs.
187    SortedSet(Vec<(f64, String)>),
188    /// A hash: map of field names to values.
189    Hash(HashMap<String, Bytes>),
190    /// An unordered set of unique string members.
191    Set(HashSet<String>),
192    /// A vector set: index config + all (element, vector) pairs.
193    #[cfg(feature = "vector")]
194    Vector {
195        metric: u8,
196        quantization: u8,
197        connectivity: u32,
198        expansion_add: u32,
199        dim: u32,
200        elements: Vec<(String, Vec<f32>)>,
201    },
202    /// A protobuf message: type name + serialized bytes.
203    #[cfg(feature = "protobuf")]
204    Proto { type_name: String, data: Bytes },
205}
206
207/// A single entry in a snapshot file.
208#[derive(Debug, Clone, PartialEq)]
209pub struct SnapEntry {
210    pub key: String,
211    pub value: SnapValue,
212    /// Remaining TTL in milliseconds, or -1 for no expiration.
213    pub expire_ms: i64,
214}
215
216impl SnapEntry {
217    /// Estimates the serialized byte size for buffer pre-allocation.
218    fn estimated_size(&self) -> usize {
219        const LEN_PREFIX: usize = 4;
220
221        let key_size = LEN_PREFIX + self.key.len();
222        let value_size = match &self.value {
223            SnapValue::String(data) => 1 + LEN_PREFIX + data.len(),
224            SnapValue::List(deque) => {
225                let items: usize = deque.iter().map(|v| LEN_PREFIX + v.len()).sum();
226                1 + 4 + items
227            }
228            SnapValue::SortedSet(members) => {
229                let items: usize = members.iter().map(|(_, m)| 8 + LEN_PREFIX + m.len()).sum();
230                1 + 4 + items
231            }
232            SnapValue::Hash(map) => {
233                let items: usize = map
234                    .iter()
235                    .map(|(f, v)| LEN_PREFIX + f.len() + LEN_PREFIX + v.len())
236                    .sum();
237                1 + 4 + items
238            }
239            SnapValue::Set(set) => {
240                let items: usize = set.iter().map(|m| LEN_PREFIX + m.len()).sum();
241                1 + 4 + items
242            }
243            #[cfg(feature = "vector")]
244            SnapValue::Vector { dim, elements, .. } => {
245                let items: usize = elements
246                    .iter()
247                    .map(|(name, _)| LEN_PREFIX + name.len() + (*dim as usize) * 4)
248                    .sum();
249                // tag + metric + quant + connectivity + expansion + dim + count + items
250                1 + 2 + 4 + 4 + 4 + 4 + items
251            }
252            #[cfg(feature = "protobuf")]
253            SnapValue::Proto { type_name, data } => {
254                1 + LEN_PREFIX + type_name.len() + LEN_PREFIX + data.len()
255            }
256        };
257        // key + value + expire_ms (i64 = 8 bytes)
258        key_size + value_size + 8
259    }
260}
261
262/// Writes a complete snapshot to disk.
263///
264/// Entries are written to a temporary file first, then atomically
265/// renamed to the final path. The caller provides an iterator over
266/// the entries to write.
267pub struct SnapshotWriter {
268    final_path: PathBuf,
269    tmp_path: PathBuf,
270    writer: BufWriter<File>,
271    /// Running CRC over all entry bytes for the footer checksum.
272    hasher: crc32fast::Hasher,
273    count: u32,
274    /// Set to `true` after a successful `finish()`. Used by the `Drop`
275    /// impl to clean up incomplete temp files.
276    finished: bool,
277    #[cfg(feature = "encryption")]
278    encryption_key: Option<crate::encryption::EncryptionKey>,
279}
280
281impl SnapshotWriter {
282    /// Creates a new snapshot writer. The file won't appear at `path`
283    /// until [`Self::finish`] is called successfully.
284    pub fn create(path: impl Into<PathBuf>, shard_id: u16) -> Result<Self, FormatError> {
285        let final_path = path.into();
286        let (tmp_path, writer) = Self::open_tmp(&final_path)?;
287        let mut writer = BufWriter::new(writer);
288
289        format::write_header(&mut writer, format::SNAP_MAGIC)?;
290        format::write_u16(&mut writer, shard_id)?;
291        format::write_u32(&mut writer, 0)?;
292
293        Ok(Self {
294            final_path,
295            tmp_path,
296            writer,
297            hasher: crc32fast::Hasher::new(),
298            count: 0,
299            finished: false,
300            #[cfg(feature = "encryption")]
301            encryption_key: None,
302        })
303    }
304
305    /// Creates a new encrypted snapshot writer.
306    #[cfg(feature = "encryption")]
307    pub fn create_encrypted(
308        path: impl Into<PathBuf>,
309        shard_id: u16,
310        key: crate::encryption::EncryptionKey,
311    ) -> Result<Self, FormatError> {
312        let final_path = path.into();
313        let (tmp_path, file) = Self::open_tmp(&final_path)?;
314        let mut writer = BufWriter::new(file);
315
316        format::write_header_versioned(
317            &mut writer,
318            format::SNAP_MAGIC,
319            format::FORMAT_VERSION_ENCRYPTED,
320        )?;
321        format::write_u16(&mut writer, shard_id)?;
322        format::write_u32(&mut writer, 0)?;
323
324        Ok(Self {
325            final_path,
326            tmp_path,
327            writer,
328            hasher: crc32fast::Hasher::new(),
329            count: 0,
330            finished: false,
331            encryption_key: Some(key),
332        })
333    }
334
335    /// Opens the temp file for writing.
336    fn open_tmp(final_path: &Path) -> Result<(PathBuf, File), FormatError> {
337        let tmp_path = final_path.with_extension("snap.tmp");
338        let mut opts = OpenOptions::new();
339        opts.write(true).create(true).truncate(true);
340        #[cfg(unix)]
341        {
342            use std::os::unix::fs::OpenOptionsExt;
343            opts.mode(0o600);
344        }
345        let file = opts.open(&tmp_path)?;
346        Ok((tmp_path, file))
347    }
348
349    /// Writes a single entry to the snapshot.
350    ///
351    /// When encrypted, each entry is written as `[nonce: 12B][len: 4B][ciphertext]`.
352    /// The footer CRC covers the encrypted bytes (nonce + len + ciphertext).
353    pub fn write_entry(&mut self, entry: &SnapEntry) -> Result<(), FormatError> {
354        let mut buf = Vec::with_capacity(entry.estimated_size());
355        format::write_bytes(&mut buf, entry.key.as_bytes())?;
356        match &entry.value {
357            SnapValue::String(data) => {
358                format::write_u8(&mut buf, TYPE_STRING)?;
359                format::write_bytes(&mut buf, data)?;
360            }
361            SnapValue::List(deque) => {
362                format::write_u8(&mut buf, TYPE_LIST)?;
363                format::write_len(&mut buf, deque.len())?;
364                for item in deque {
365                    format::write_bytes(&mut buf, item)?;
366                }
367            }
368            SnapValue::SortedSet(members) => {
369                format::write_u8(&mut buf, TYPE_SORTED_SET)?;
370                format::write_len(&mut buf, members.len())?;
371                for (score, member) in members {
372                    format::write_f64(&mut buf, *score)?;
373                    format::write_bytes(&mut buf, member.as_bytes())?;
374                }
375            }
376            SnapValue::Hash(map) => {
377                format::write_u8(&mut buf, TYPE_HASH)?;
378                format::write_len(&mut buf, map.len())?;
379                for (field, value) in map {
380                    format::write_bytes(&mut buf, field.as_bytes())?;
381                    format::write_bytes(&mut buf, value)?;
382                }
383            }
384            SnapValue::Set(set) => {
385                format::write_u8(&mut buf, TYPE_SET)?;
386                format::write_len(&mut buf, set.len())?;
387                for member in set {
388                    format::write_bytes(&mut buf, member.as_bytes())?;
389                }
390            }
391            #[cfg(feature = "vector")]
392            SnapValue::Vector {
393                metric,
394                quantization,
395                connectivity,
396                expansion_add,
397                dim,
398                elements,
399            } => {
400                format::write_u8(&mut buf, TYPE_VECTOR)?;
401                format::write_u8(&mut buf, *metric)?;
402                format::write_u8(&mut buf, *quantization)?;
403                format::write_u32(&mut buf, *connectivity)?;
404                format::write_u32(&mut buf, *expansion_add)?;
405                format::write_u32(&mut buf, *dim)?;
406                format::write_len(&mut buf, elements.len())?;
407                for (name, vector) in elements {
408                    format::write_bytes(&mut buf, name.as_bytes())?;
409                    for &v in vector {
410                        format::write_f32(&mut buf, v)?;
411                    }
412                }
413            }
414            #[cfg(feature = "protobuf")]
415            SnapValue::Proto { type_name, data } => {
416                format::write_u8(&mut buf, TYPE_PROTO)?;
417                format::write_bytes(&mut buf, type_name.as_bytes())?;
418                format::write_bytes(&mut buf, data)?;
419            }
420        }
421        format::write_i64(&mut buf, entry.expire_ms)?;
422
423        #[cfg(feature = "encryption")]
424        if let Some(ref key) = self.encryption_key {
425            let (nonce, ciphertext) = crate::encryption::encrypt_record(key, &buf)?;
426            let ct_len = u32::try_from(ciphertext.len()).map_err(|_| {
427                io::Error::new(
428                    io::ErrorKind::InvalidInput,
429                    "encrypted record exceeds u32::MAX bytes",
430                )
431            })?;
432            // footer CRC covers the encrypted envelope
433            self.hasher.update(&nonce);
434            let ct_len_bytes = ct_len.to_le_bytes();
435            self.hasher.update(&ct_len_bytes);
436            self.hasher.update(&ciphertext);
437            self.writer.write_all(&nonce)?;
438            format::write_u32(&mut self.writer, ct_len)?;
439            self.writer.write_all(&ciphertext)?;
440            self.count += 1;
441            return Ok(());
442        }
443
444        self.hasher.update(&buf);
445        self.writer.write_all(&buf)?;
446        self.count += 1;
447        Ok(())
448    }
449
450    /// Finalizes the snapshot: writes the footer CRC, flushes, and
451    /// atomically renames the temp file to the final path.
452    pub fn finish(mut self) -> Result<(), FormatError> {
453        // write footer CRC — clone the hasher so we don't move out of self
454        let checksum = self.hasher.clone().finalize();
455        format::write_u32(&mut self.writer, checksum)?;
456        self.writer.flush()?;
457        self.writer.get_ref().sync_all()?;
458
459        // rewrite the header with the correct entry count.
460        // open a second handle for the seek — the BufWriter is already
461        // flushed and synced above.
462        {
463            use std::io::{Seek, SeekFrom};
464            let mut file = fs::OpenOptions::new().write(true).open(&self.tmp_path)?;
465            // header: 4 (magic) + 1 (version) + 2 (shard_id) = 7 bytes
466            file.seek(SeekFrom::Start(7))?;
467            format::write_u32(&mut file, self.count)?;
468            file.sync_all()?;
469        }
470
471        // atomic rename
472        fs::rename(&self.tmp_path, &self.final_path)?;
473
474        // fsync the parent directory so the rename is durable across crashes
475        if let Some(parent) = self.final_path.parent() {
476            if let Ok(dir) = File::open(parent) {
477                let _ = dir.sync_all();
478            }
479        }
480
481        self.finished = true;
482        Ok(())
483    }
484}
485
486impl Drop for SnapshotWriter {
487    fn drop(&mut self) {
488        if !self.finished {
489            // best-effort cleanup of incomplete temp file
490            let _ = fs::remove_file(&self.tmp_path);
491        }
492    }
493}
494
495/// Reads entries from a snapshot file.
496pub struct SnapshotReader {
497    reader: BufReader<File>,
498    pub shard_id: u16,
499    pub entry_count: u32,
500    read_so_far: u32,
501    hasher: crc32fast::Hasher,
502    /// Format version — v1 has no type tags, v2 has type-tagged entries, v3 is encrypted.
503    version: u8,
504    #[cfg(feature = "encryption")]
505    encryption_key: Option<crate::encryption::EncryptionKey>,
506}
507
508impl SnapshotReader {
509    /// Opens a snapshot file and reads the header.
510    pub fn open(path: impl AsRef<Path>) -> Result<Self, FormatError> {
511        let file = File::open(path.as_ref())?;
512        let mut reader = BufReader::new(file);
513
514        let version = format::read_header(&mut reader, format::SNAP_MAGIC)?;
515
516        if version == format::FORMAT_VERSION_ENCRYPTED {
517            return Err(FormatError::EncryptionRequired);
518        }
519
520        let shard_id = format::read_u16(&mut reader)?;
521        let entry_count = format::read_u32(&mut reader)?;
522
523        Ok(Self {
524            reader,
525            shard_id,
526            entry_count,
527            read_so_far: 0,
528            hasher: crc32fast::Hasher::new(),
529            version,
530            #[cfg(feature = "encryption")]
531            encryption_key: None,
532        })
533    }
534
535    /// Opens a snapshot file with an encryption key for decrypting v3 entries.
536    ///
537    /// Also handles v1/v2 (plaintext) files — the key is simply unused.
538    #[cfg(feature = "encryption")]
539    pub fn open_encrypted(
540        path: impl AsRef<Path>,
541        key: crate::encryption::EncryptionKey,
542    ) -> Result<Self, FormatError> {
543        let file = File::open(path.as_ref())?;
544        let mut reader = BufReader::new(file);
545
546        let version = format::read_header(&mut reader, format::SNAP_MAGIC)?;
547        let shard_id = format::read_u16(&mut reader)?;
548        let entry_count = format::read_u32(&mut reader)?;
549
550        Ok(Self {
551            reader,
552            shard_id,
553            entry_count,
554            read_so_far: 0,
555            hasher: crc32fast::Hasher::new(),
556            version,
557            encryption_key: Some(key),
558        })
559    }
560
561    /// Reads the next entry. Returns `None` when all entries have been read.
562    pub fn read_entry(&mut self) -> Result<Option<SnapEntry>, FormatError> {
563        if self.read_so_far >= self.entry_count {
564            return Ok(None);
565        }
566
567        #[cfg(feature = "encryption")]
568        if self.version == format::FORMAT_VERSION_ENCRYPTED {
569            return self.read_encrypted_entry();
570        }
571
572        self.read_plaintext_entry()
573    }
574
575    /// Reads a plaintext (v1/v2) entry.
576    fn read_plaintext_entry(&mut self) -> Result<Option<SnapEntry>, FormatError> {
577        let mut buf = Vec::new();
578
579        let key_bytes = format::read_bytes(&mut self.reader)?;
580        format::write_bytes(&mut buf, &key_bytes)?;
581
582        let value = if self.version == 1 {
583            // v1: no type tag, value is always a string
584            let value_bytes = format::read_bytes(&mut self.reader)?;
585            format::write_bytes(&mut buf, &value_bytes)?;
586            SnapValue::String(Bytes::from(value_bytes))
587        } else {
588            // v2+: type-tagged values
589            let type_tag = format::read_u8(&mut self.reader)?;
590            format::write_u8(&mut buf, type_tag)?;
591            match type_tag {
592                TYPE_STRING => {
593                    let value_bytes = format::read_bytes(&mut self.reader)?;
594                    format::write_bytes(&mut buf, &value_bytes)?;
595                    SnapValue::String(Bytes::from(value_bytes))
596                }
597                TYPE_LIST => {
598                    let count = format::read_u32(&mut self.reader)?;
599                    format::validate_collection_count(count, "list")?;
600                    format::write_u32(&mut buf, count)?;
601                    let mut deque = VecDeque::with_capacity(format::capped_capacity(count));
602                    for _ in 0..count {
603                        let item = format::read_bytes(&mut self.reader)?;
604                        format::write_bytes(&mut buf, &item)?;
605                        deque.push_back(Bytes::from(item));
606                    }
607                    SnapValue::List(deque)
608                }
609                TYPE_SORTED_SET => {
610                    let count = format::read_u32(&mut self.reader)?;
611                    format::validate_collection_count(count, "sorted set")?;
612                    format::write_u32(&mut buf, count)?;
613                    let mut members = Vec::with_capacity(format::capped_capacity(count));
614                    for _ in 0..count {
615                        let score = format::read_f64(&mut self.reader)?;
616                        format::write_f64(&mut buf, score)?;
617                        let member_bytes = format::read_bytes(&mut self.reader)?;
618                        format::write_bytes(&mut buf, &member_bytes)?;
619                        let member = parse_utf8(member_bytes, "member")?;
620                        members.push((score, member));
621                    }
622                    SnapValue::SortedSet(members)
623                }
624                TYPE_HASH => {
625                    let count = format::read_u32(&mut self.reader)?;
626                    format::validate_collection_count(count, "hash")?;
627                    format::write_u32(&mut buf, count)?;
628                    let mut map = HashMap::with_capacity(format::capped_capacity(count));
629                    for _ in 0..count {
630                        let field_bytes = format::read_bytes(&mut self.reader)?;
631                        format::write_bytes(&mut buf, &field_bytes)?;
632                        let field = parse_utf8(field_bytes, "hash field")?;
633                        let value_bytes = format::read_bytes(&mut self.reader)?;
634                        format::write_bytes(&mut buf, &value_bytes)?;
635                        map.insert(field, Bytes::from(value_bytes));
636                    }
637                    SnapValue::Hash(map)
638                }
639                TYPE_SET => {
640                    let count = format::read_u32(&mut self.reader)?;
641                    format::validate_collection_count(count, "set")?;
642                    format::write_u32(&mut buf, count)?;
643                    let mut set = HashSet::with_capacity(format::capped_capacity(count));
644                    for _ in 0..count {
645                        let member_bytes = format::read_bytes(&mut self.reader)?;
646                        format::write_bytes(&mut buf, &member_bytes)?;
647                        let member = parse_utf8(member_bytes, "set member")?;
648                        set.insert(member);
649                    }
650                    SnapValue::Set(set)
651                }
652                #[cfg(feature = "vector")]
653                TYPE_VECTOR => {
654                    let metric = format::read_u8(&mut self.reader)?;
655                    if metric > 2 {
656                        return Err(FormatError::InvalidData(format!(
657                            "unknown vector metric: {metric}"
658                        )));
659                    }
660                    format::write_u8(&mut buf, metric)?;
661                    let quantization = format::read_u8(&mut self.reader)?;
662                    if quantization > 2 {
663                        return Err(FormatError::InvalidData(format!(
664                            "unknown vector quantization: {quantization}"
665                        )));
666                    }
667                    format::write_u8(&mut buf, quantization)?;
668                    let connectivity = format::read_u32(&mut self.reader)?;
669                    format::write_u32(&mut buf, connectivity)?;
670                    let expansion_add = format::read_u32(&mut self.reader)?;
671                    format::write_u32(&mut buf, expansion_add)?;
672                    let dim = format::read_u32(&mut self.reader)?;
673                    if dim > format::MAX_PERSISTED_VECTOR_DIMS {
674                        return Err(FormatError::InvalidData(format!(
675                            "vector dimension {dim} exceeds max {}",
676                            format::MAX_PERSISTED_VECTOR_DIMS
677                        )));
678                    }
679                    format::write_u32(&mut buf, dim)?;
680                    let count = format::read_u32(&mut self.reader)?;
681                    if count > format::MAX_PERSISTED_VECTOR_COUNT {
682                        return Err(FormatError::InvalidData(format!(
683                            "vector element count {count} exceeds max {}",
684                            format::MAX_PERSISTED_VECTOR_COUNT
685                        )));
686                    }
687                    format::validate_vector_total(dim, count)?;
688                    format::write_u32(&mut buf, count)?;
689                    let mut elements = Vec::with_capacity(format::capped_capacity(count));
690                    for _ in 0..count {
691                        let name_bytes = format::read_bytes(&mut self.reader)?;
692                        format::write_bytes(&mut buf, &name_bytes)?;
693                        let name = parse_utf8(name_bytes, "vector element name")?;
694                        let mut vector = Vec::with_capacity(dim as usize);
695                        for _ in 0..dim {
696                            let v = format::read_f32(&mut self.reader)?;
697                            format::write_f32(&mut buf, v)?;
698                            vector.push(v);
699                        }
700                        elements.push((name, vector));
701                    }
702                    SnapValue::Vector {
703                        metric,
704                        quantization,
705                        connectivity,
706                        expansion_add,
707                        dim,
708                        elements,
709                    }
710                }
711                #[cfg(feature = "protobuf")]
712                TYPE_PROTO => {
713                    let type_name_bytes = format::read_bytes(&mut self.reader)?;
714                    format::write_bytes(&mut buf, &type_name_bytes)?;
715                    let type_name = parse_utf8(type_name_bytes, "proto type_name")?;
716                    let data = format::read_bytes(&mut self.reader)?;
717                    format::write_bytes(&mut buf, &data)?;
718                    SnapValue::Proto {
719                        type_name,
720                        data: Bytes::from(data),
721                    }
722                }
723                _ => {
724                    return Err(FormatError::UnknownTag(type_tag));
725                }
726            }
727        };
728
729        let expire_ms = format::read_i64(&mut self.reader)?;
730        format::write_i64(&mut buf, expire_ms)?;
731        self.hasher.update(&buf);
732
733        let key = parse_utf8(key_bytes, "key")?;
734
735        self.read_so_far += 1;
736        Ok(Some(SnapEntry {
737            key,
738            value,
739            expire_ms,
740        }))
741    }
742
743    /// Reads an encrypted (v3) entry: nonce + len + ciphertext.
744    /// Decrypts to get the same bytes as a plaintext entry, then parses.
745    #[cfg(feature = "encryption")]
746    fn read_encrypted_entry(&mut self) -> Result<Option<SnapEntry>, FormatError> {
747        use std::io::Read as _;
748
749        let key = self
750            .encryption_key
751            .as_ref()
752            .ok_or(FormatError::EncryptionRequired)?;
753
754        let mut nonce = [0u8; crate::encryption::NONCE_SIZE];
755        self.reader
756            .read_exact(&mut nonce)
757            .map_err(|e| match e.kind() {
758                io::ErrorKind::UnexpectedEof => FormatError::UnexpectedEof,
759                _ => FormatError::Io(e),
760            })?;
761
762        let ct_len = format::read_u32(&mut self.reader)? as usize;
763        if ct_len > format::MAX_FIELD_LEN {
764            return Err(FormatError::Io(io::Error::new(
765                io::ErrorKind::InvalidData,
766                format!("encrypted entry length {ct_len} exceeds maximum"),
767            )));
768        }
769
770        let mut ciphertext = vec![0u8; ct_len];
771        self.reader
772            .read_exact(&mut ciphertext)
773            .map_err(|e| match e.kind() {
774                io::ErrorKind::UnexpectedEof => FormatError::UnexpectedEof,
775                _ => FormatError::Io(e),
776            })?;
777
778        // footer CRC covers the encrypted envelope
779        self.hasher.update(&nonce);
780        let ct_len_bytes = (ct_len as u32).to_le_bytes();
781        self.hasher.update(&ct_len_bytes);
782        self.hasher.update(&ciphertext);
783
784        let plaintext = crate::encryption::decrypt_record(key, &nonce, &ciphertext)?;
785
786        let mut cursor = io::Cursor::new(&plaintext);
787        let entry_key = read_snap_string(&mut cursor, "key")?;
788        let value = parse_snap_value(&mut cursor)?;
789        let expire_ms = format::read_i64(&mut cursor)?;
790
791        self.read_so_far += 1;
792        Ok(Some(SnapEntry {
793            key: entry_key,
794            value,
795            expire_ms,
796        }))
797    }
798
799    /// Verifies the footer CRC32 after all entries have been read.
800    /// Must be called after reading all entries.
801    pub fn verify_footer(self) -> Result<(), FormatError> {
802        let expected = self.hasher.finalize();
803        let mut reader = self.reader;
804        let stored = format::read_u32(&mut reader)?;
805        format::verify_crc32_values(expected, stored)
806    }
807}
808
/// Builds the path of a shard's snapshot file inside `data_dir`.
///
/// The file name is `shard-{shard_id}.snap`.
pub fn snapshot_path(data_dir: &Path, shard_id: u16) -> PathBuf {
    let file_name = format!("shard-{shard_id}.snap");
    let mut full_path = data_dir.to_path_buf();
    full_path.push(file_name);
    full_path
}
813
814#[cfg(test)]
815mod tests {
816    use super::*;
817
    /// Return type for fallible tests; boxing the error lets `?` be used
    /// on any error that implements `std::error::Error`.
    type Result = std::result::Result<(), Box<dyn std::error::Error>>;
819
820    fn temp_dir() -> tempfile::TempDir {
821        tempfile::tempdir().expect("create temp dir")
822    }
823
824    #[test]
825    fn empty_snapshot_round_trip() -> Result {
826        let dir = temp_dir();
827        let path = dir.path().join("empty.snap");
828
829        {
830            let writer = SnapshotWriter::create(&path, 0)?;
831            writer.finish()?;
832        }
833
834        let reader = SnapshotReader::open(&path)?;
835        assert_eq!(reader.shard_id, 0);
836        assert_eq!(reader.entry_count, 0);
837        reader.verify_footer()?;
838        Ok(())
839    }
840
841    #[test]
842    fn entries_round_trip() -> Result {
843        let dir = temp_dir();
844        let path = dir.path().join("data.snap");
845
846        let entries = vec![
847            SnapEntry {
848                key: "hello".into(),
849                value: SnapValue::String(Bytes::from("world")),
850                expire_ms: -1,
851            },
852            SnapEntry {
853                key: "ttl".into(),
854                value: SnapValue::String(Bytes::from("expiring")),
855                expire_ms: 5000,
856            },
857            SnapEntry {
858                key: "empty".into(),
859                value: SnapValue::String(Bytes::new()),
860                expire_ms: -1,
861            },
862        ];
863
864        {
865            let mut writer = SnapshotWriter::create(&path, 7)?;
866            for entry in &entries {
867                writer.write_entry(entry)?;
868            }
869            writer.finish()?;
870        }
871
872        let mut reader = SnapshotReader::open(&path)?;
873        assert_eq!(reader.shard_id, 7);
874        assert_eq!(reader.entry_count, 3);
875
876        let mut got = Vec::new();
877        while let Some(entry) = reader.read_entry()? {
878            got.push(entry);
879        }
880        assert_eq!(entries, got);
881        reader.verify_footer()?;
882        Ok(())
883    }
884
885    #[test]
886    fn corrupt_footer_detected() -> Result {
887        let dir = temp_dir();
888        let path = dir.path().join("corrupt.snap");
889
890        {
891            let mut writer = SnapshotWriter::create(&path, 0)?;
892            writer.write_entry(&SnapEntry {
893                key: "k".into(),
894                value: SnapValue::String(Bytes::from("v")),
895                expire_ms: -1,
896            })?;
897            writer.finish()?;
898        }
899
900        // corrupt the last byte (footer CRC)
901        let mut data = fs::read(&path)?;
902        let last = data.len() - 1;
903        data[last] ^= 0xFF;
904        fs::write(&path, &data)?;
905
906        let mut reader = SnapshotReader::open(&path)?;
907        // reading entries should still work
908        reader.read_entry()?;
909        // but footer verification should fail
910        let err = reader.verify_footer().unwrap_err();
911        assert!(matches!(err, FormatError::ChecksumMismatch { .. }));
912        Ok(())
913    }
914
915    #[test]
916    fn atomic_rename_prevents_partial_snapshots() -> Result {
917        let dir = temp_dir();
918        let path = dir.path().join("atomic.snap");
919
920        // write an initial snapshot
921        {
922            let mut writer = SnapshotWriter::create(&path, 0)?;
923            writer.write_entry(&SnapEntry {
924                key: "original".into(),
925                value: SnapValue::String(Bytes::from("data")),
926                expire_ms: -1,
927            })?;
928            writer.finish()?;
929        }
930
931        // start a second snapshot but don't finish it
932        {
933            let mut writer = SnapshotWriter::create(&path, 0)?;
934            writer.write_entry(&SnapEntry {
935                key: "new".into(),
936                value: SnapValue::String(Bytes::from("partial")),
937                expire_ms: -1,
938            })?;
939            // drop without finish — simulates a crash
940            drop(writer);
941        }
942
943        // the original snapshot should still be intact
944        let mut reader = SnapshotReader::open(&path)?;
945        let entry = reader.read_entry()?.unwrap();
946        assert_eq!(entry.key, "original");
947
948        // the tmp file should have been cleaned up by Drop
949        let tmp = path.with_extension("snap.tmp");
950        assert!(!tmp.exists(), "drop should clean up incomplete tmp file");
951        Ok(())
952    }
953
954    #[test]
955    fn ttl_entries_preserved() -> Result {
956        let dir = temp_dir();
957        let path = dir.path().join("ttl.snap");
958
959        let entry = SnapEntry {
960            key: "expires".into(),
961            value: SnapValue::String(Bytes::from("soon")),
962            expire_ms: 42_000,
963        };
964
965        {
966            let mut writer = SnapshotWriter::create(&path, 0)?;
967            writer.write_entry(&entry)?;
968            writer.finish()?;
969        }
970
971        let mut reader = SnapshotReader::open(&path)?;
972        let got = reader.read_entry()?.unwrap();
973        assert_eq!(got.expire_ms, 42_000);
974        reader.verify_footer()?;
975        Ok(())
976    }
977
978    #[test]
979    fn list_entries_round_trip() -> Result {
980        let dir = temp_dir();
981        let path = dir.path().join("list.snap");
982
983        let mut deque = VecDeque::new();
984        deque.push_back(Bytes::from("a"));
985        deque.push_back(Bytes::from("b"));
986        deque.push_back(Bytes::from("c"));
987
988        let entries = vec![
989            SnapEntry {
990                key: "mylist".into(),
991                value: SnapValue::List(deque),
992                expire_ms: -1,
993            },
994            SnapEntry {
995                key: "mystr".into(),
996                value: SnapValue::String(Bytes::from("val")),
997                expire_ms: 1000,
998            },
999        ];
1000
1001        {
1002            let mut writer = SnapshotWriter::create(&path, 0)?;
1003            for entry in &entries {
1004                writer.write_entry(entry)?;
1005            }
1006            writer.finish()?;
1007        }
1008
1009        let mut reader = SnapshotReader::open(&path)?;
1010        let mut got = Vec::new();
1011        while let Some(entry) = reader.read_entry()? {
1012            got.push(entry);
1013        }
1014        assert_eq!(entries, got);
1015        reader.verify_footer()?;
1016        Ok(())
1017    }
1018
1019    #[test]
1020    fn sorted_set_entries_round_trip() -> Result {
1021        let dir = temp_dir();
1022        let path = dir.path().join("zset.snap");
1023
1024        let entries = vec![
1025            SnapEntry {
1026                key: "board".into(),
1027                value: SnapValue::SortedSet(vec![
1028                    (100.0, "alice".into()),
1029                    (200.0, "bob".into()),
1030                    (150.0, "charlie".into()),
1031                ]),
1032                expire_ms: -1,
1033            },
1034            SnapEntry {
1035                key: "mystr".into(),
1036                value: SnapValue::String(Bytes::from("val")),
1037                expire_ms: 1000,
1038            },
1039        ];
1040
1041        {
1042            let mut writer = SnapshotWriter::create(&path, 0)?;
1043            for entry in &entries {
1044                writer.write_entry(entry)?;
1045            }
1046            writer.finish()?;
1047        }
1048
1049        let mut reader = SnapshotReader::open(&path)?;
1050        let mut got = Vec::new();
1051        while let Some(entry) = reader.read_entry()? {
1052            got.push(entry);
1053        }
1054        assert_eq!(entries, got);
1055        reader.verify_footer()?;
1056        Ok(())
1057    }
1058
1059    #[test]
1060    fn snapshot_path_format() {
1061        let p = snapshot_path(Path::new("/data"), 5);
1062        assert_eq!(p, PathBuf::from("/data/shard-5.snap"));
1063    }
1064
1065    #[test]
1066    fn truncated_snapshot_detected() -> Result {
1067        let dir = temp_dir();
1068        let path = dir.path().join("truncated.snap");
1069
1070        // write a valid 2-entry snapshot
1071        {
1072            let mut writer = SnapshotWriter::create(&path, 0)?;
1073            writer.write_entry(&SnapEntry {
1074                key: "a".into(),
1075                value: SnapValue::String(Bytes::from("1")),
1076                expire_ms: -1,
1077            })?;
1078            writer.write_entry(&SnapEntry {
1079                key: "b".into(),
1080                value: SnapValue::String(Bytes::from("2")),
1081                expire_ms: 5000,
1082            })?;
1083            writer.finish()?;
1084        }
1085
1086        // truncate the file mid-way through the second entry
1087        let data = fs::read(&path)?;
1088        let truncated = &data[..data.len() - 20];
1089        fs::write(&path, truncated)?;
1090
1091        let mut reader = SnapshotReader::open(&path)?;
1092        assert_eq!(reader.entry_count, 2);
1093
1094        // first entry should still be readable
1095        let first = reader.read_entry()?;
1096        assert!(first.is_some());
1097
1098        // second entry should fail with an EOF-related error
1099        let err = reader.read_entry().unwrap_err();
1100        assert!(
1101            matches!(err, FormatError::UnexpectedEof | FormatError::Io(_)),
1102            "expected EOF error, got {err:?}"
1103        );
1104        Ok(())
1105    }
1106
1107    #[cfg(feature = "vector")]
1108    #[test]
1109    fn vector_entries_round_trip() -> Result {
1110        let dir = temp_dir();
1111        let path = dir.path().join("vec.snap");
1112
1113        let entries = vec![SnapEntry {
1114            key: "embeddings".into(),
1115            value: SnapValue::Vector {
1116                metric: 0,
1117                quantization: 0,
1118                connectivity: 16,
1119                expansion_add: 64,
1120                dim: 3,
1121                elements: vec![
1122                    ("doc1".into(), vec![0.1, 0.2, 0.3]),
1123                    ("doc2".into(), vec![0.4, 0.5, 0.6]),
1124                ],
1125            },
1126            expire_ms: -1,
1127        }];
1128
1129        {
1130            let mut writer = SnapshotWriter::create(&path, 0)?;
1131            for entry in &entries {
1132                writer.write_entry(entry)?;
1133            }
1134            writer.finish()?;
1135        }
1136
1137        let mut reader = SnapshotReader::open(&path)?;
1138        let mut got = Vec::new();
1139        while let Some(entry) = reader.read_entry()? {
1140            got.push(entry);
1141        }
1142        assert_eq!(entries, got);
1143        reader.verify_footer()?;
1144        Ok(())
1145    }
1146
1147    #[cfg(feature = "vector")]
1148    #[test]
1149    fn vector_empty_set_round_trip() -> Result {
1150        let dir = temp_dir();
1151        let path = dir.path().join("vec_empty.snap");
1152
1153        let entries = vec![SnapEntry {
1154            key: "empty_vecs".into(),
1155            value: SnapValue::Vector {
1156                metric: 2, // inner product
1157                quantization: 2,
1158                connectivity: 8,
1159                expansion_add: 32,
1160                dim: 128,
1161                elements: vec![],
1162            },
1163            expire_ms: 5000,
1164        }];
1165
1166        {
1167            let mut writer = SnapshotWriter::create(&path, 0)?;
1168            for entry in &entries {
1169                writer.write_entry(entry)?;
1170            }
1171            writer.finish()?;
1172        }
1173
1174        let mut reader = SnapshotReader::open(&path)?;
1175        let got = reader.read_entry()?.unwrap();
1176        assert_eq!(entries[0], got);
1177        reader.verify_footer()?;
1178        Ok(())
1179    }
1180
1181    #[cfg(feature = "encryption")]
1182    mod encrypted {
1183        use super::*;
1184        use crate::encryption::EncryptionKey;
1185
        /// Return type for the fallible encrypted-snapshot tests; boxing the
        /// error lets `?` be used on any error implementing `std::error::Error`.
        type Result = std::result::Result<(), Box<dyn std::error::Error>>;
1187
1188        fn test_key() -> EncryptionKey {
1189            EncryptionKey::from_bytes([0x42; 32])
1190        }
1191
1192        #[test]
1193        fn encrypted_snapshot_round_trip() -> Result {
1194            let dir = temp_dir();
1195            let path = dir.path().join("enc.snap");
1196            let key = test_key();
1197
1198            let entries = vec![
1199                SnapEntry {
1200                    key: "hello".into(),
1201                    value: SnapValue::String(Bytes::from("world")),
1202                    expire_ms: -1,
1203                },
1204                SnapEntry {
1205                    key: "ttl".into(),
1206                    value: SnapValue::String(Bytes::from("expiring")),
1207                    expire_ms: 5000,
1208                },
1209            ];
1210
1211            {
1212                let mut writer = SnapshotWriter::create_encrypted(&path, 7, key.clone())?;
1213                for entry in &entries {
1214                    writer.write_entry(entry)?;
1215                }
1216                writer.finish()?;
1217            }
1218
1219            let mut reader = SnapshotReader::open_encrypted(&path, key)?;
1220            assert_eq!(reader.shard_id, 7);
1221            assert_eq!(reader.entry_count, 2);
1222
1223            let mut got = Vec::new();
1224            while let Some(entry) = reader.read_entry()? {
1225                got.push(entry);
1226            }
1227            assert_eq!(entries, got);
1228            reader.verify_footer()?;
1229            Ok(())
1230        }
1231
1232        #[test]
1233        fn encrypted_snapshot_wrong_key_fails() -> Result {
1234            let dir = temp_dir();
1235            let path = dir.path().join("enc_bad.snap");
1236            let key = test_key();
1237            let wrong_key = EncryptionKey::from_bytes([0xFF; 32]);
1238
1239            {
1240                let mut writer = SnapshotWriter::create_encrypted(&path, 0, key)?;
1241                writer.write_entry(&SnapEntry {
1242                    key: "k".into(),
1243                    value: SnapValue::String(Bytes::from("v")),
1244                    expire_ms: -1,
1245                })?;
1246                writer.finish()?;
1247            }
1248
1249            let mut reader = SnapshotReader::open_encrypted(&path, wrong_key)?;
1250            let err = reader.read_entry().unwrap_err();
1251            assert!(matches!(err, FormatError::DecryptionFailed));
1252            Ok(())
1253        }
1254
1255        #[test]
1256        fn v2_snapshot_readable_with_encryption_key() -> Result {
1257            let dir = temp_dir();
1258            let path = dir.path().join("v2.snap");
1259            let key = test_key();
1260
1261            {
1262                let mut writer = SnapshotWriter::create(&path, 0)?;
1263                writer.write_entry(&SnapEntry {
1264                    key: "k".into(),
1265                    value: SnapValue::String(Bytes::from("v")),
1266                    expire_ms: -1,
1267                })?;
1268                writer.finish()?;
1269            }
1270
1271            let mut reader = SnapshotReader::open_encrypted(&path, key)?;
1272            let entry = reader.read_entry()?.unwrap();
1273            assert_eq!(entry.key, "k");
1274            reader.verify_footer()?;
1275            Ok(())
1276        }
1277
1278        #[test]
1279        fn v3_snapshot_without_key_returns_error() -> Result {
1280            let dir = temp_dir();
1281            let path = dir.path().join("v3_nokey.snap");
1282            let key = test_key();
1283
1284            {
1285                let mut writer = SnapshotWriter::create_encrypted(&path, 0, key)?;
1286                writer.write_entry(&SnapEntry {
1287                    key: "k".into(),
1288                    value: SnapValue::String(Bytes::from("v")),
1289                    expire_ms: -1,
1290                })?;
1291                writer.finish()?;
1292            }
1293
1294            let result = SnapshotReader::open(&path);
1295            assert!(matches!(result, Err(FormatError::EncryptionRequired)));
1296            Ok(())
1297        }
1298
1299        #[test]
1300        fn encrypted_snapshot_with_all_types() -> Result {
1301            let dir = temp_dir();
1302            let path = dir.path().join("enc_types.snap");
1303            let key = test_key();
1304
1305            let mut deque = VecDeque::new();
1306            deque.push_back(Bytes::from("a"));
1307            deque.push_back(Bytes::from("b"));
1308
1309            let mut hash = HashMap::new();
1310            hash.insert("f1".into(), Bytes::from("v1"));
1311
1312            let mut set = HashSet::new();
1313            set.insert("m1".into());
1314            set.insert("m2".into());
1315
1316            let entries = vec![
1317                SnapEntry {
1318                    key: "str".into(),
1319                    value: SnapValue::String(Bytes::from("val")),
1320                    expire_ms: -1,
1321                },
1322                SnapEntry {
1323                    key: "list".into(),
1324                    value: SnapValue::List(deque),
1325                    expire_ms: 1000,
1326                },
1327                SnapEntry {
1328                    key: "zset".into(),
1329                    value: SnapValue::SortedSet(vec![(1.0, "a".into()), (2.0, "b".into())]),
1330                    expire_ms: -1,
1331                },
1332                SnapEntry {
1333                    key: "hash".into(),
1334                    value: SnapValue::Hash(hash),
1335                    expire_ms: -1,
1336                },
1337                SnapEntry {
1338                    key: "set".into(),
1339                    value: SnapValue::Set(set),
1340                    expire_ms: -1,
1341                },
1342            ];
1343
1344            {
1345                let mut writer = SnapshotWriter::create_encrypted(&path, 0, key.clone())?;
1346                for entry in &entries {
1347                    writer.write_entry(entry)?;
1348                }
1349                writer.finish()?;
1350            }
1351
1352            let mut reader = SnapshotReader::open_encrypted(&path, key)?;
1353            let mut got = Vec::new();
1354            while let Some(entry) = reader.read_entry()? {
1355                got.push(entry);
1356            }
1357            assert_eq!(entries, got);
1358            reader.verify_footer()?;
1359            Ok(())
1360        }
1361    }
1362}