1use std::collections::{HashMap, HashSet, VecDeque};
25use std::fs::{self, File, OpenOptions};
26use std::io::{self, BufReader, BufWriter, Write};
27use std::path::{Path, PathBuf};
28
29use bytes::Bytes;
30
31use crate::format::{self, FormatError};
32
// One-byte type tags. The writer emits one of these ahead of each value
// payload and the readers dispatch on it to pick the matching `SnapValue`
// variant.
const TYPE_STRING: u8 = 0;
const TYPE_LIST: u8 = 1;
const TYPE_SORTED_SET: u8 = 2;
const TYPE_HASH: u8 = 3;
const TYPE_SET: u8 = 4;
// Only defined when the `vector` feature is enabled.
#[cfg(feature = "vector")]
const TYPE_VECTOR: u8 = 6;
// Only defined when the `protobuf` feature is enabled.
#[cfg(feature = "protobuf")]
const TYPE_PROTO: u8 = 5;
43
44fn parse_utf8(bytes: Vec<u8>, field: &str) -> Result<String, FormatError> {
48 String::from_utf8(bytes).map_err(|_| {
49 FormatError::Io(io::Error::new(
50 io::ErrorKind::InvalidData,
51 format!("{field} is not valid utf-8"),
52 ))
53 })
54}
55
/// Reads one byte field from `r` via `format::read_bytes` and validates it as
/// UTF-8, using `field` to label the error if validation fails.
#[cfg(feature = "encryption")]
fn read_snap_string(r: &mut impl io::Read, field: &str) -> Result<String, FormatError> {
    let raw = format::read_bytes(r)?;
    let text = parse_utf8(raw, field)?;
    Ok(text)
}
62
/// Decodes one tagged value from `r` into a [`SnapValue`].
///
/// Reads a single type-tag byte, then the payload layout selected by that
/// tag. Only compiled with the `encryption` feature; within this file it is
/// called on the decrypted plaintext of an encrypted record.
///
/// # Errors
///
/// Returns `FormatError::UnknownTag` for a tag that matches no compiled-in
/// variant, `FormatError::InvalidData` for out-of-range vector parameters,
/// and propagates any error from the `format` read/validation helpers or
/// from UTF-8 validation of string fields.
#[cfg(feature = "encryption")]
fn parse_snap_value(r: &mut impl io::Read) -> Result<SnapValue, FormatError> {
    let type_tag = format::read_u8(r)?;
    match type_tag {
        // Single byte field.
        TYPE_STRING => {
            let v = format::read_bytes(r)?;
            Ok(SnapValue::String(Bytes::from(v)))
        }
        // u32 count followed by that many byte fields.
        TYPE_LIST => {
            let count = format::read_u32(r)?;
            format::validate_collection_count(count, "list")?;
            // Capacity comes from `capped_capacity` rather than the raw
            // count read from the stream.
            let mut deque = VecDeque::with_capacity(format::capped_capacity(count));
            for _ in 0..count {
                deque.push_back(Bytes::from(format::read_bytes(r)?));
            }
            Ok(SnapValue::List(deque))
        }
        // u32 count followed by (f64 score, UTF-8 member) pairs.
        TYPE_SORTED_SET => {
            let count = format::read_u32(r)?;
            format::validate_collection_count(count, "sorted set")?;
            let mut members = Vec::with_capacity(format::capped_capacity(count));
            for _ in 0..count {
                let score = format::read_f64(r)?;
                let member = read_snap_string(r, "member")?;
                members.push((score, member));
            }
            Ok(SnapValue::SortedSet(members))
        }
        // u32 count followed by (UTF-8 field, byte value) pairs.
        TYPE_HASH => {
            let count = format::read_u32(r)?;
            format::validate_collection_count(count, "hash")?;
            let mut map = HashMap::with_capacity(format::capped_capacity(count));
            for _ in 0..count {
                let field = read_snap_string(r, "hash field")?;
                let value = format::read_bytes(r)?;
                map.insert(field, Bytes::from(value));
            }
            Ok(SnapValue::Hash(map))
        }
        // u32 count followed by UTF-8 members.
        TYPE_SET => {
            let count = format::read_u32(r)?;
            format::validate_collection_count(count, "set")?;
            let mut set = HashSet::with_capacity(format::capped_capacity(count));
            for _ in 0..count {
                let member = read_snap_string(r, "set member")?;
                set.insert(member);
            }
            Ok(SnapValue::Set(set))
        }
        // metric (u8), quantization (u8), connectivity, expansion_add, dim,
        // element count (all u32), then per element a UTF-8 name followed by
        // exactly `dim` f32 components.
        #[cfg(feature = "vector")]
        TYPE_VECTOR => {
            let metric = format::read_u8(r)?;
            // Metric codes above 2 are rejected.
            if metric > 2 {
                return Err(FormatError::InvalidData(format!(
                    "unknown vector metric: {metric}"
                )));
            }
            let quantization = format::read_u8(r)?;
            // Quantization codes above 2 are rejected.
            if quantization > 2 {
                return Err(FormatError::InvalidData(format!(
                    "unknown vector quantization: {quantization}"
                )));
            }
            let connectivity = format::read_u32(r)?;
            let expansion_add = format::read_u32(r)?;
            let dim = format::read_u32(r)?;
            if dim > format::MAX_PERSISTED_VECTOR_DIMS {
                return Err(FormatError::InvalidData(format!(
                    "vector dimension {dim} exceeds max {}",
                    format::MAX_PERSISTED_VECTOR_DIMS
                )));
            }
            let count = format::read_u32(r)?;
            if count > format::MAX_PERSISTED_VECTOR_COUNT {
                return Err(FormatError::InvalidData(format!(
                    "vector element count {count} exceeds max {}",
                    format::MAX_PERSISTED_VECTOR_COUNT
                )));
            }
            // Additional check on the dim/count combination, performed
            // before any per-element allocation.
            format::validate_vector_total(dim, count)?;
            let mut elements = Vec::with_capacity(format::capped_capacity(count));
            for _ in 0..count {
                let name = read_snap_string(r, "vector element name")?;
                let mut vector = Vec::with_capacity(dim as usize);
                for _ in 0..dim {
                    vector.push(format::read_f32(r)?);
                }
                elements.push((name, vector));
            }
            Ok(SnapValue::Vector {
                metric,
                quantization,
                connectivity,
                expansion_add,
                dim,
                elements,
            })
        }
        // UTF-8 type name followed by one byte field.
        #[cfg(feature = "protobuf")]
        TYPE_PROTO => {
            let type_name = read_snap_string(r, "proto type_name")?;
            let data = format::read_bytes(r)?;
            Ok(SnapValue::Proto {
                type_name,
                data: Bytes::from(data),
            })
        }
        _ => Err(FormatError::UnknownTag(type_tag)),
    }
}
178
/// A value as stored in a snapshot entry; each variant corresponds to one
/// of the `TYPE_*` tags written ahead of the payload.
#[derive(Debug, Clone, PartialEq)]
pub enum SnapValue {
    /// Byte string (tag `TYPE_STRING`).
    String(Bytes),
    /// Ordered sequence of byte strings (tag `TYPE_LIST`).
    List(VecDeque<Bytes>),
    /// `(score, member)` pairs, serialized in the order held here
    /// (tag `TYPE_SORTED_SET`).
    SortedSet(Vec<(f64, String)>),
    /// Field name to byte-string value (tag `TYPE_HASH`).
    Hash(HashMap<String, Bytes>),
    /// Set of string members (tag `TYPE_SET`).
    Set(HashSet<String>),
    /// Vector collection (tag `TYPE_VECTOR`); requires the `vector` feature.
    #[cfg(feature = "vector")]
    Vector {
        /// Metric code; the readers reject values above 2.
        metric: u8,
        /// Quantization code; the readers reject values above 2.
        quantization: u8,
        /// Stored and restored verbatim as a u32.
        connectivity: u32,
        /// Stored and restored verbatim as a u32.
        expansion_add: u32,
        /// Number of f32 components the readers consume per element.
        dim: u32,
        /// `(name, components)` pairs.
        elements: Vec<(String, Vec<f32>)>,
    },
    /// Typed binary payload (tag `TYPE_PROTO`); requires the `protobuf`
    /// feature.
    #[cfg(feature = "protobuf")]
    Proto { type_name: String, data: Bytes },
}
206
/// One key/value record of a snapshot together with its expiry field.
#[derive(Debug, Clone, PartialEq)]
pub struct SnapEntry {
    /// Entry key; serialized from its UTF-8 bytes and validated as UTF-8
    /// when read back.
    pub key: String,
    /// The stored value.
    pub value: SnapValue,
    /// Expiry value, written verbatim as an i64 after the value payload.
    /// NOTE(review): tests in this file use -1 for entries without a TTL and
    /// positive values otherwise, presumably milliseconds; confirm with callers.
    pub expire_ms: i64,
}
215
216impl SnapEntry {
217 fn estimated_size(&self) -> usize {
219 const LEN_PREFIX: usize = 4;
220
221 let key_size = LEN_PREFIX + self.key.len();
222 let value_size = match &self.value {
223 SnapValue::String(data) => 1 + LEN_PREFIX + data.len(),
224 SnapValue::List(deque) => {
225 let items: usize = deque.iter().map(|v| LEN_PREFIX + v.len()).sum();
226 1 + 4 + items
227 }
228 SnapValue::SortedSet(members) => {
229 let items: usize = members.iter().map(|(_, m)| 8 + LEN_PREFIX + m.len()).sum();
230 1 + 4 + items
231 }
232 SnapValue::Hash(map) => {
233 let items: usize = map
234 .iter()
235 .map(|(f, v)| LEN_PREFIX + f.len() + LEN_PREFIX + v.len())
236 .sum();
237 1 + 4 + items
238 }
239 SnapValue::Set(set) => {
240 let items: usize = set.iter().map(|m| LEN_PREFIX + m.len()).sum();
241 1 + 4 + items
242 }
243 #[cfg(feature = "vector")]
244 SnapValue::Vector { dim, elements, .. } => {
245 let items: usize = elements
246 .iter()
247 .map(|(name, _)| LEN_PREFIX + name.len() + (*dim as usize) * 4)
248 .sum();
249 1 + 2 + 4 + 4 + 4 + 4 + items
251 }
252 #[cfg(feature = "protobuf")]
253 SnapValue::Proto { type_name, data } => {
254 1 + LEN_PREFIX + type_name.len() + LEN_PREFIX + data.len()
255 }
256 };
257 key_size + value_size + 8
259 }
260}
261
/// Writes a snapshot to a temporary file and renames it onto the final path
/// when `finish` succeeds. Dropping the writer without finishing removes the
/// temporary file.
pub struct SnapshotWriter {
    /// Destination the temporary file is renamed to by `finish`.
    final_path: PathBuf,
    /// Temporary file that receives all writes until `finish`.
    tmp_path: PathBuf,
    /// Buffered handle to the temporary file.
    writer: BufWriter<File>,
    /// Running CRC32 over every record written; finalized into the footer.
    hasher: crc32fast::Hasher,
    /// Number of entries written so far; stored into the header by `finish`.
    count: u32,
    /// Set once `finish` has completed; `Drop` deletes the temporary file
    /// while this is still false.
    finished: bool,
    /// When `Some`, each record is encrypted before being written.
    #[cfg(feature = "encryption")]
    encryption_key: Option<crate::encryption::EncryptionKey>,
}
280
281impl SnapshotWriter {
282 pub fn create(path: impl Into<PathBuf>, shard_id: u16) -> Result<Self, FormatError> {
285 let final_path = path.into();
286 let (tmp_path, writer) = Self::open_tmp(&final_path)?;
287 let mut writer = BufWriter::new(writer);
288
289 format::write_header(&mut writer, format::SNAP_MAGIC)?;
290 format::write_u16(&mut writer, shard_id)?;
291 format::write_u32(&mut writer, 0)?;
292
293 Ok(Self {
294 final_path,
295 tmp_path,
296 writer,
297 hasher: crc32fast::Hasher::new(),
298 count: 0,
299 finished: false,
300 #[cfg(feature = "encryption")]
301 encryption_key: None,
302 })
303 }
304
305 #[cfg(feature = "encryption")]
307 pub fn create_encrypted(
308 path: impl Into<PathBuf>,
309 shard_id: u16,
310 key: crate::encryption::EncryptionKey,
311 ) -> Result<Self, FormatError> {
312 let final_path = path.into();
313 let (tmp_path, file) = Self::open_tmp(&final_path)?;
314 let mut writer = BufWriter::new(file);
315
316 format::write_header_versioned(
317 &mut writer,
318 format::SNAP_MAGIC,
319 format::FORMAT_VERSION_ENCRYPTED,
320 )?;
321 format::write_u16(&mut writer, shard_id)?;
322 format::write_u32(&mut writer, 0)?;
323
324 Ok(Self {
325 final_path,
326 tmp_path,
327 writer,
328 hasher: crc32fast::Hasher::new(),
329 count: 0,
330 finished: false,
331 encryption_key: Some(key),
332 })
333 }
334
335 fn open_tmp(final_path: &Path) -> Result<(PathBuf, File), FormatError> {
337 let tmp_path = final_path.with_extension("snap.tmp");
338 let mut opts = OpenOptions::new();
339 opts.write(true).create(true).truncate(true);
340 #[cfg(unix)]
341 {
342 use std::os::unix::fs::OpenOptionsExt;
343 opts.mode(0o600);
344 }
345 let file = opts.open(&tmp_path)?;
346 Ok((tmp_path, file))
347 }
348
349 pub fn write_entry(&mut self, entry: &SnapEntry) -> Result<(), FormatError> {
354 let mut buf = Vec::with_capacity(entry.estimated_size());
355 format::write_bytes(&mut buf, entry.key.as_bytes())?;
356 match &entry.value {
357 SnapValue::String(data) => {
358 format::write_u8(&mut buf, TYPE_STRING)?;
359 format::write_bytes(&mut buf, data)?;
360 }
361 SnapValue::List(deque) => {
362 format::write_u8(&mut buf, TYPE_LIST)?;
363 format::write_len(&mut buf, deque.len())?;
364 for item in deque {
365 format::write_bytes(&mut buf, item)?;
366 }
367 }
368 SnapValue::SortedSet(members) => {
369 format::write_u8(&mut buf, TYPE_SORTED_SET)?;
370 format::write_len(&mut buf, members.len())?;
371 for (score, member) in members {
372 format::write_f64(&mut buf, *score)?;
373 format::write_bytes(&mut buf, member.as_bytes())?;
374 }
375 }
376 SnapValue::Hash(map) => {
377 format::write_u8(&mut buf, TYPE_HASH)?;
378 format::write_len(&mut buf, map.len())?;
379 for (field, value) in map {
380 format::write_bytes(&mut buf, field.as_bytes())?;
381 format::write_bytes(&mut buf, value)?;
382 }
383 }
384 SnapValue::Set(set) => {
385 format::write_u8(&mut buf, TYPE_SET)?;
386 format::write_len(&mut buf, set.len())?;
387 for member in set {
388 format::write_bytes(&mut buf, member.as_bytes())?;
389 }
390 }
391 #[cfg(feature = "vector")]
392 SnapValue::Vector {
393 metric,
394 quantization,
395 connectivity,
396 expansion_add,
397 dim,
398 elements,
399 } => {
400 format::write_u8(&mut buf, TYPE_VECTOR)?;
401 format::write_u8(&mut buf, *metric)?;
402 format::write_u8(&mut buf, *quantization)?;
403 format::write_u32(&mut buf, *connectivity)?;
404 format::write_u32(&mut buf, *expansion_add)?;
405 format::write_u32(&mut buf, *dim)?;
406 format::write_len(&mut buf, elements.len())?;
407 for (name, vector) in elements {
408 format::write_bytes(&mut buf, name.as_bytes())?;
409 for &v in vector {
410 format::write_f32(&mut buf, v)?;
411 }
412 }
413 }
414 #[cfg(feature = "protobuf")]
415 SnapValue::Proto { type_name, data } => {
416 format::write_u8(&mut buf, TYPE_PROTO)?;
417 format::write_bytes(&mut buf, type_name.as_bytes())?;
418 format::write_bytes(&mut buf, data)?;
419 }
420 }
421 format::write_i64(&mut buf, entry.expire_ms)?;
422
423 #[cfg(feature = "encryption")]
424 if let Some(ref key) = self.encryption_key {
425 let (nonce, ciphertext) = crate::encryption::encrypt_record(key, &buf)?;
426 let ct_len = u32::try_from(ciphertext.len()).map_err(|_| {
427 io::Error::new(
428 io::ErrorKind::InvalidInput,
429 "encrypted record exceeds u32::MAX bytes",
430 )
431 })?;
432 self.hasher.update(&nonce);
434 let ct_len_bytes = ct_len.to_le_bytes();
435 self.hasher.update(&ct_len_bytes);
436 self.hasher.update(&ciphertext);
437 self.writer.write_all(&nonce)?;
438 format::write_u32(&mut self.writer, ct_len)?;
439 self.writer.write_all(&ciphertext)?;
440 self.count += 1;
441 return Ok(());
442 }
443
444 self.hasher.update(&buf);
445 self.writer.write_all(&buf)?;
446 self.count += 1;
447 Ok(())
448 }
449
450 pub fn finish(mut self) -> Result<(), FormatError> {
453 let checksum = self.hasher.clone().finalize();
455 format::write_u32(&mut self.writer, checksum)?;
456 self.writer.flush()?;
457 self.writer.get_ref().sync_all()?;
458
459 {
463 use std::io::{Seek, SeekFrom};
464 let mut file = fs::OpenOptions::new().write(true).open(&self.tmp_path)?;
465 file.seek(SeekFrom::Start(7))?;
467 format::write_u32(&mut file, self.count)?;
468 file.sync_all()?;
469 }
470
471 fs::rename(&self.tmp_path, &self.final_path)?;
473
474 if let Some(parent) = self.final_path.parent() {
476 if let Ok(dir) = File::open(parent) {
477 let _ = dir.sync_all();
478 }
479 }
480
481 self.finished = true;
482 Ok(())
483 }
484}
485
486impl Drop for SnapshotWriter {
487 fn drop(&mut self) {
488 if !self.finished {
489 let _ = fs::remove_file(&self.tmp_path);
491 }
492 }
493}
494
/// Sequential reader for a snapshot file: entries are pulled with
/// `read_entry`, then `verify_footer` checks the stored CRC32.
pub struct SnapshotReader {
    /// Buffered handle positioned just past the header after opening.
    reader: BufReader<File>,
    /// Shard id read from the header.
    pub shard_id: u16,
    /// Entry count read from the header.
    pub entry_count: u32,
    /// Number of entries returned so far.
    read_so_far: u32,
    /// Running CRC32 over the records read; compared against the footer.
    hasher: crc32fast::Hasher,
    /// Format version read from the header; selects the decoding path.
    version: u8,
    /// Key used to decrypt records of encrypted snapshots, if one was given.
    #[cfg(feature = "encryption")]
    encryption_key: Option<crate::encryption::EncryptionKey>,
}
507
508impl SnapshotReader {
509 pub fn open(path: impl AsRef<Path>) -> Result<Self, FormatError> {
511 let file = File::open(path.as_ref())?;
512 let mut reader = BufReader::new(file);
513
514 let version = format::read_header(&mut reader, format::SNAP_MAGIC)?;
515
516 if version == format::FORMAT_VERSION_ENCRYPTED {
517 return Err(FormatError::EncryptionRequired);
518 }
519
520 let shard_id = format::read_u16(&mut reader)?;
521 let entry_count = format::read_u32(&mut reader)?;
522
523 Ok(Self {
524 reader,
525 shard_id,
526 entry_count,
527 read_so_far: 0,
528 hasher: crc32fast::Hasher::new(),
529 version,
530 #[cfg(feature = "encryption")]
531 encryption_key: None,
532 })
533 }
534
535 #[cfg(feature = "encryption")]
539 pub fn open_encrypted(
540 path: impl AsRef<Path>,
541 key: crate::encryption::EncryptionKey,
542 ) -> Result<Self, FormatError> {
543 let file = File::open(path.as_ref())?;
544 let mut reader = BufReader::new(file);
545
546 let version = format::read_header(&mut reader, format::SNAP_MAGIC)?;
547 let shard_id = format::read_u16(&mut reader)?;
548 let entry_count = format::read_u32(&mut reader)?;
549
550 Ok(Self {
551 reader,
552 shard_id,
553 entry_count,
554 read_so_far: 0,
555 hasher: crc32fast::Hasher::new(),
556 version,
557 encryption_key: Some(key),
558 })
559 }
560
561 pub fn read_entry(&mut self) -> Result<Option<SnapEntry>, FormatError> {
563 if self.read_so_far >= self.entry_count {
564 return Ok(None);
565 }
566
567 #[cfg(feature = "encryption")]
568 if self.version == format::FORMAT_VERSION_ENCRYPTED {
569 return self.read_encrypted_entry();
570 }
571
572 self.read_plaintext_entry()
573 }
574
    /// Reads one unencrypted entry from the underlying file.
    ///
    /// Every field is re-encoded into a scratch buffer as it is read, and
    /// that buffer is fed to the CRC hasher once the whole record has been
    /// consumed. This mirrors the writer, which hashes the encoded record
    /// buffer. For version 1 files the value carries no type tag and is
    /// always decoded as a string.
    ///
    /// # Errors
    ///
    /// Returns `FormatError::UnknownTag` for an unrecognized type tag,
    /// `FormatError::InvalidData` for out-of-range vector parameters, and
    /// propagates read, validation and UTF-8 errors. On error neither the
    /// hasher nor `read_so_far` is updated.
    fn read_plaintext_entry(&mut self) -> Result<Option<SnapEntry>, FormatError> {
        // Scratch copy of the record bytes, used only for the checksum.
        let mut buf = Vec::new();

        let key_bytes = format::read_bytes(&mut self.reader)?;
        format::write_bytes(&mut buf, &key_bytes)?;

        let value = if self.version == 1 {
            // Version 1: untagged value, always a string.
            let value_bytes = format::read_bytes(&mut self.reader)?;
            format::write_bytes(&mut buf, &value_bytes)?;
            SnapValue::String(Bytes::from(value_bytes))
        } else {
            // Later versions: a type-tag byte selects the payload layout.
            let type_tag = format::read_u8(&mut self.reader)?;
            format::write_u8(&mut buf, type_tag)?;
            match type_tag {
                TYPE_STRING => {
                    let value_bytes = format::read_bytes(&mut self.reader)?;
                    format::write_bytes(&mut buf, &value_bytes)?;
                    SnapValue::String(Bytes::from(value_bytes))
                }
                TYPE_LIST => {
                    let count = format::read_u32(&mut self.reader)?;
                    format::validate_collection_count(count, "list")?;
                    format::write_u32(&mut buf, count)?;
                    let mut deque = VecDeque::with_capacity(format::capped_capacity(count));
                    for _ in 0..count {
                        let item = format::read_bytes(&mut self.reader)?;
                        format::write_bytes(&mut buf, &item)?;
                        deque.push_back(Bytes::from(item));
                    }
                    SnapValue::List(deque)
                }
                TYPE_SORTED_SET => {
                    let count = format::read_u32(&mut self.reader)?;
                    format::validate_collection_count(count, "sorted set")?;
                    format::write_u32(&mut buf, count)?;
                    let mut members = Vec::with_capacity(format::capped_capacity(count));
                    for _ in 0..count {
                        let score = format::read_f64(&mut self.reader)?;
                        format::write_f64(&mut buf, score)?;
                        let member_bytes = format::read_bytes(&mut self.reader)?;
                        // Copy the raw bytes into the scratch buffer before
                        // `parse_utf8` consumes them.
                        format::write_bytes(&mut buf, &member_bytes)?;
                        let member = parse_utf8(member_bytes, "member")?;
                        members.push((score, member));
                    }
                    SnapValue::SortedSet(members)
                }
                TYPE_HASH => {
                    let count = format::read_u32(&mut self.reader)?;
                    format::validate_collection_count(count, "hash")?;
                    format::write_u32(&mut buf, count)?;
                    let mut map = HashMap::with_capacity(format::capped_capacity(count));
                    for _ in 0..count {
                        let field_bytes = format::read_bytes(&mut self.reader)?;
                        format::write_bytes(&mut buf, &field_bytes)?;
                        let field = parse_utf8(field_bytes, "hash field")?;
                        let value_bytes = format::read_bytes(&mut self.reader)?;
                        format::write_bytes(&mut buf, &value_bytes)?;
                        map.insert(field, Bytes::from(value_bytes));
                    }
                    SnapValue::Hash(map)
                }
                TYPE_SET => {
                    let count = format::read_u32(&mut self.reader)?;
                    format::validate_collection_count(count, "set")?;
                    format::write_u32(&mut buf, count)?;
                    let mut set = HashSet::with_capacity(format::capped_capacity(count));
                    for _ in 0..count {
                        let member_bytes = format::read_bytes(&mut self.reader)?;
                        format::write_bytes(&mut buf, &member_bytes)?;
                        let member = parse_utf8(member_bytes, "set member")?;
                        set.insert(member);
                    }
                    SnapValue::Set(set)
                }
                #[cfg(feature = "vector")]
                TYPE_VECTOR => {
                    // Each header field is validated before it is copied
                    // into the scratch buffer.
                    let metric = format::read_u8(&mut self.reader)?;
                    if metric > 2 {
                        return Err(FormatError::InvalidData(format!(
                            "unknown vector metric: {metric}"
                        )));
                    }
                    format::write_u8(&mut buf, metric)?;
                    let quantization = format::read_u8(&mut self.reader)?;
                    if quantization > 2 {
                        return Err(FormatError::InvalidData(format!(
                            "unknown vector quantization: {quantization}"
                        )));
                    }
                    format::write_u8(&mut buf, quantization)?;
                    let connectivity = format::read_u32(&mut self.reader)?;
                    format::write_u32(&mut buf, connectivity)?;
                    let expansion_add = format::read_u32(&mut self.reader)?;
                    format::write_u32(&mut buf, expansion_add)?;
                    let dim = format::read_u32(&mut self.reader)?;
                    if dim > format::MAX_PERSISTED_VECTOR_DIMS {
                        return Err(FormatError::InvalidData(format!(
                            "vector dimension {dim} exceeds max {}",
                            format::MAX_PERSISTED_VECTOR_DIMS
                        )));
                    }
                    format::write_u32(&mut buf, dim)?;
                    let count = format::read_u32(&mut self.reader)?;
                    if count > format::MAX_PERSISTED_VECTOR_COUNT {
                        return Err(FormatError::InvalidData(format!(
                            "vector element count {count} exceeds max {}",
                            format::MAX_PERSISTED_VECTOR_COUNT
                        )));
                    }
                    format::validate_vector_total(dim, count)?;
                    format::write_u32(&mut buf, count)?;
                    let mut elements = Vec::with_capacity(format::capped_capacity(count));
                    for _ in 0..count {
                        let name_bytes = format::read_bytes(&mut self.reader)?;
                        format::write_bytes(&mut buf, &name_bytes)?;
                        let name = parse_utf8(name_bytes, "vector element name")?;
                        // Exactly `dim` components are read per element.
                        let mut vector = Vec::with_capacity(dim as usize);
                        for _ in 0..dim {
                            let v = format::read_f32(&mut self.reader)?;
                            format::write_f32(&mut buf, v)?;
                            vector.push(v);
                        }
                        elements.push((name, vector));
                    }
                    SnapValue::Vector {
                        metric,
                        quantization,
                        connectivity,
                        expansion_add,
                        dim,
                        elements,
                    }
                }
                #[cfg(feature = "protobuf")]
                TYPE_PROTO => {
                    let type_name_bytes = format::read_bytes(&mut self.reader)?;
                    format::write_bytes(&mut buf, &type_name_bytes)?;
                    let type_name = parse_utf8(type_name_bytes, "proto type_name")?;
                    let data = format::read_bytes(&mut self.reader)?;
                    format::write_bytes(&mut buf, &data)?;
                    SnapValue::Proto {
                        type_name,
                        data: Bytes::from(data),
                    }
                }
                _ => {
                    return Err(FormatError::UnknownTag(type_tag));
                }
            }
        };

        let expire_ms = format::read_i64(&mut self.reader)?;
        format::write_i64(&mut buf, expire_ms)?;
        // The record is complete: fold its bytes into the running checksum.
        self.hasher.update(&buf);

        // The key is validated last, so an invalid key is reported only
        // after the record has been hashed.
        let key = parse_utf8(key_bytes, "key")?;

        self.read_so_far += 1;
        Ok(Some(SnapEntry {
            key,
            value,
            expire_ms,
        }))
    }
742
743 #[cfg(feature = "encryption")]
746 fn read_encrypted_entry(&mut self) -> Result<Option<SnapEntry>, FormatError> {
747 use std::io::Read as _;
748
749 let key = self
750 .encryption_key
751 .as_ref()
752 .ok_or(FormatError::EncryptionRequired)?;
753
754 let mut nonce = [0u8; crate::encryption::NONCE_SIZE];
755 self.reader
756 .read_exact(&mut nonce)
757 .map_err(|e| match e.kind() {
758 io::ErrorKind::UnexpectedEof => FormatError::UnexpectedEof,
759 _ => FormatError::Io(e),
760 })?;
761
762 let ct_len = format::read_u32(&mut self.reader)? as usize;
763 if ct_len > format::MAX_FIELD_LEN {
764 return Err(FormatError::Io(io::Error::new(
765 io::ErrorKind::InvalidData,
766 format!("encrypted entry length {ct_len} exceeds maximum"),
767 )));
768 }
769
770 let mut ciphertext = vec![0u8; ct_len];
771 self.reader
772 .read_exact(&mut ciphertext)
773 .map_err(|e| match e.kind() {
774 io::ErrorKind::UnexpectedEof => FormatError::UnexpectedEof,
775 _ => FormatError::Io(e),
776 })?;
777
778 self.hasher.update(&nonce);
780 let ct_len_bytes = (ct_len as u32).to_le_bytes();
781 self.hasher.update(&ct_len_bytes);
782 self.hasher.update(&ciphertext);
783
784 let plaintext = crate::encryption::decrypt_record(key, &nonce, &ciphertext)?;
785
786 let mut cursor = io::Cursor::new(&plaintext);
787 let entry_key = read_snap_string(&mut cursor, "key")?;
788 let value = parse_snap_value(&mut cursor)?;
789 let expire_ms = format::read_i64(&mut cursor)?;
790
791 self.read_so_far += 1;
792 Ok(Some(SnapEntry {
793 key: entry_key,
794 value,
795 expire_ms,
796 }))
797 }
798
799 pub fn verify_footer(self) -> Result<(), FormatError> {
802 let expected = self.hasher.finalize();
803 let mut reader = self.reader;
804 let stored = format::read_u32(&mut reader)?;
805 format::verify_crc32_values(expected, stored)
806 }
807}
808
/// Builds the snapshot file path for a shard: `<data_dir>/shard-<shard_id>.snap`.
pub fn snapshot_path(data_dir: &Path, shard_id: u16) -> PathBuf {
    let file_name = format!("shard-{shard_id}.snap");
    let mut full = data_dir.to_path_buf();
    full.push(file_name);
    full
}
813
814#[cfg(test)]
815mod tests {
816 use super::*;
817
    // Return type shared by the tests so they can use `?` on any error type.
    type Result = std::result::Result<(), Box<dyn std::error::Error>>;
819
820 fn temp_dir() -> tempfile::TempDir {
821 tempfile::tempdir().expect("create temp dir")
822 }
823
824 #[test]
825 fn empty_snapshot_round_trip() -> Result {
826 let dir = temp_dir();
827 let path = dir.path().join("empty.snap");
828
829 {
830 let writer = SnapshotWriter::create(&path, 0)?;
831 writer.finish()?;
832 }
833
834 let reader = SnapshotReader::open(&path)?;
835 assert_eq!(reader.shard_id, 0);
836 assert_eq!(reader.entry_count, 0);
837 reader.verify_footer()?;
838 Ok(())
839 }
840
841 #[test]
842 fn entries_round_trip() -> Result {
843 let dir = temp_dir();
844 let path = dir.path().join("data.snap");
845
846 let entries = vec![
847 SnapEntry {
848 key: "hello".into(),
849 value: SnapValue::String(Bytes::from("world")),
850 expire_ms: -1,
851 },
852 SnapEntry {
853 key: "ttl".into(),
854 value: SnapValue::String(Bytes::from("expiring")),
855 expire_ms: 5000,
856 },
857 SnapEntry {
858 key: "empty".into(),
859 value: SnapValue::String(Bytes::new()),
860 expire_ms: -1,
861 },
862 ];
863
864 {
865 let mut writer = SnapshotWriter::create(&path, 7)?;
866 for entry in &entries {
867 writer.write_entry(entry)?;
868 }
869 writer.finish()?;
870 }
871
872 let mut reader = SnapshotReader::open(&path)?;
873 assert_eq!(reader.shard_id, 7);
874 assert_eq!(reader.entry_count, 3);
875
876 let mut got = Vec::new();
877 while let Some(entry) = reader.read_entry()? {
878 got.push(entry);
879 }
880 assert_eq!(entries, got);
881 reader.verify_footer()?;
882 Ok(())
883 }
884
885 #[test]
886 fn corrupt_footer_detected() -> Result {
887 let dir = temp_dir();
888 let path = dir.path().join("corrupt.snap");
889
890 {
891 let mut writer = SnapshotWriter::create(&path, 0)?;
892 writer.write_entry(&SnapEntry {
893 key: "k".into(),
894 value: SnapValue::String(Bytes::from("v")),
895 expire_ms: -1,
896 })?;
897 writer.finish()?;
898 }
899
900 let mut data = fs::read(&path)?;
902 let last = data.len() - 1;
903 data[last] ^= 0xFF;
904 fs::write(&path, &data)?;
905
906 let mut reader = SnapshotReader::open(&path)?;
907 reader.read_entry()?;
909 let err = reader.verify_footer().unwrap_err();
911 assert!(matches!(err, FormatError::ChecksumMismatch { .. }));
912 Ok(())
913 }
914
915 #[test]
916 fn atomic_rename_prevents_partial_snapshots() -> Result {
917 let dir = temp_dir();
918 let path = dir.path().join("atomic.snap");
919
920 {
922 let mut writer = SnapshotWriter::create(&path, 0)?;
923 writer.write_entry(&SnapEntry {
924 key: "original".into(),
925 value: SnapValue::String(Bytes::from("data")),
926 expire_ms: -1,
927 })?;
928 writer.finish()?;
929 }
930
931 {
933 let mut writer = SnapshotWriter::create(&path, 0)?;
934 writer.write_entry(&SnapEntry {
935 key: "new".into(),
936 value: SnapValue::String(Bytes::from("partial")),
937 expire_ms: -1,
938 })?;
939 drop(writer);
941 }
942
943 let mut reader = SnapshotReader::open(&path)?;
945 let entry = reader.read_entry()?.unwrap();
946 assert_eq!(entry.key, "original");
947
948 let tmp = path.with_extension("snap.tmp");
950 assert!(!tmp.exists(), "drop should clean up incomplete tmp file");
951 Ok(())
952 }
953
954 #[test]
955 fn ttl_entries_preserved() -> Result {
956 let dir = temp_dir();
957 let path = dir.path().join("ttl.snap");
958
959 let entry = SnapEntry {
960 key: "expires".into(),
961 value: SnapValue::String(Bytes::from("soon")),
962 expire_ms: 42_000,
963 };
964
965 {
966 let mut writer = SnapshotWriter::create(&path, 0)?;
967 writer.write_entry(&entry)?;
968 writer.finish()?;
969 }
970
971 let mut reader = SnapshotReader::open(&path)?;
972 let got = reader.read_entry()?.unwrap();
973 assert_eq!(got.expire_ms, 42_000);
974 reader.verify_footer()?;
975 Ok(())
976 }
977
978 #[test]
979 fn list_entries_round_trip() -> Result {
980 let dir = temp_dir();
981 let path = dir.path().join("list.snap");
982
983 let mut deque = VecDeque::new();
984 deque.push_back(Bytes::from("a"));
985 deque.push_back(Bytes::from("b"));
986 deque.push_back(Bytes::from("c"));
987
988 let entries = vec![
989 SnapEntry {
990 key: "mylist".into(),
991 value: SnapValue::List(deque),
992 expire_ms: -1,
993 },
994 SnapEntry {
995 key: "mystr".into(),
996 value: SnapValue::String(Bytes::from("val")),
997 expire_ms: 1000,
998 },
999 ];
1000
1001 {
1002 let mut writer = SnapshotWriter::create(&path, 0)?;
1003 for entry in &entries {
1004 writer.write_entry(entry)?;
1005 }
1006 writer.finish()?;
1007 }
1008
1009 let mut reader = SnapshotReader::open(&path)?;
1010 let mut got = Vec::new();
1011 while let Some(entry) = reader.read_entry()? {
1012 got.push(entry);
1013 }
1014 assert_eq!(entries, got);
1015 reader.verify_footer()?;
1016 Ok(())
1017 }
1018
1019 #[test]
1020 fn sorted_set_entries_round_trip() -> Result {
1021 let dir = temp_dir();
1022 let path = dir.path().join("zset.snap");
1023
1024 let entries = vec![
1025 SnapEntry {
1026 key: "board".into(),
1027 value: SnapValue::SortedSet(vec![
1028 (100.0, "alice".into()),
1029 (200.0, "bob".into()),
1030 (150.0, "charlie".into()),
1031 ]),
1032 expire_ms: -1,
1033 },
1034 SnapEntry {
1035 key: "mystr".into(),
1036 value: SnapValue::String(Bytes::from("val")),
1037 expire_ms: 1000,
1038 },
1039 ];
1040
1041 {
1042 let mut writer = SnapshotWriter::create(&path, 0)?;
1043 for entry in &entries {
1044 writer.write_entry(entry)?;
1045 }
1046 writer.finish()?;
1047 }
1048
1049 let mut reader = SnapshotReader::open(&path)?;
1050 let mut got = Vec::new();
1051 while let Some(entry) = reader.read_entry()? {
1052 got.push(entry);
1053 }
1054 assert_eq!(entries, got);
1055 reader.verify_footer()?;
1056 Ok(())
1057 }
1058
1059 #[test]
1060 fn snapshot_path_format() {
1061 let p = snapshot_path(Path::new("/data"), 5);
1062 assert_eq!(p, PathBuf::from("/data/shard-5.snap"));
1063 }
1064
1065 #[test]
1066 fn truncated_snapshot_detected() -> Result {
1067 let dir = temp_dir();
1068 let path = dir.path().join("truncated.snap");
1069
1070 {
1072 let mut writer = SnapshotWriter::create(&path, 0)?;
1073 writer.write_entry(&SnapEntry {
1074 key: "a".into(),
1075 value: SnapValue::String(Bytes::from("1")),
1076 expire_ms: -1,
1077 })?;
1078 writer.write_entry(&SnapEntry {
1079 key: "b".into(),
1080 value: SnapValue::String(Bytes::from("2")),
1081 expire_ms: 5000,
1082 })?;
1083 writer.finish()?;
1084 }
1085
1086 let data = fs::read(&path)?;
1088 let truncated = &data[..data.len() - 20];
1089 fs::write(&path, truncated)?;
1090
1091 let mut reader = SnapshotReader::open(&path)?;
1092 assert_eq!(reader.entry_count, 2);
1093
1094 let first = reader.read_entry()?;
1096 assert!(first.is_some());
1097
1098 let err = reader.read_entry().unwrap_err();
1100 assert!(
1101 matches!(err, FormatError::UnexpectedEof | FormatError::Io(_)),
1102 "expected EOF error, got {err:?}"
1103 );
1104 Ok(())
1105 }
1106
1107 #[cfg(feature = "vector")]
1108 #[test]
1109 fn vector_entries_round_trip() -> Result {
1110 let dir = temp_dir();
1111 let path = dir.path().join("vec.snap");
1112
1113 let entries = vec![SnapEntry {
1114 key: "embeddings".into(),
1115 value: SnapValue::Vector {
1116 metric: 0,
1117 quantization: 0,
1118 connectivity: 16,
1119 expansion_add: 64,
1120 dim: 3,
1121 elements: vec![
1122 ("doc1".into(), vec![0.1, 0.2, 0.3]),
1123 ("doc2".into(), vec![0.4, 0.5, 0.6]),
1124 ],
1125 },
1126 expire_ms: -1,
1127 }];
1128
1129 {
1130 let mut writer = SnapshotWriter::create(&path, 0)?;
1131 for entry in &entries {
1132 writer.write_entry(entry)?;
1133 }
1134 writer.finish()?;
1135 }
1136
1137 let mut reader = SnapshotReader::open(&path)?;
1138 let mut got = Vec::new();
1139 while let Some(entry) = reader.read_entry()? {
1140 got.push(entry);
1141 }
1142 assert_eq!(entries, got);
1143 reader.verify_footer()?;
1144 Ok(())
1145 }
1146
1147 #[cfg(feature = "vector")]
1148 #[test]
1149 fn vector_empty_set_round_trip() -> Result {
1150 let dir = temp_dir();
1151 let path = dir.path().join("vec_empty.snap");
1152
1153 let entries = vec![SnapEntry {
1154 key: "empty_vecs".into(),
1155 value: SnapValue::Vector {
1156 metric: 2, quantization: 2,
1158 connectivity: 8,
1159 expansion_add: 32,
1160 dim: 128,
1161 elements: vec![],
1162 },
1163 expire_ms: 5000,
1164 }];
1165
1166 {
1167 let mut writer = SnapshotWriter::create(&path, 0)?;
1168 for entry in &entries {
1169 writer.write_entry(entry)?;
1170 }
1171 writer.finish()?;
1172 }
1173
1174 let mut reader = SnapshotReader::open(&path)?;
1175 let got = reader.read_entry()?.unwrap();
1176 assert_eq!(entries[0], got);
1177 reader.verify_footer()?;
1178 Ok(())
1179 }
1180
1181 #[cfg(feature = "encryption")]
1182 mod encrypted {
1183 use super::*;
1184 use crate::encryption::EncryptionKey;
1185
        // Return type shared by the encrypted tests so they can use `?` on
        // any error type.
        type Result = std::result::Result<(), Box<dyn std::error::Error>>;
1187
1188 fn test_key() -> EncryptionKey {
1189 EncryptionKey::from_bytes([0x42; 32])
1190 }
1191
1192 #[test]
1193 fn encrypted_snapshot_round_trip() -> Result {
1194 let dir = temp_dir();
1195 let path = dir.path().join("enc.snap");
1196 let key = test_key();
1197
1198 let entries = vec![
1199 SnapEntry {
1200 key: "hello".into(),
1201 value: SnapValue::String(Bytes::from("world")),
1202 expire_ms: -1,
1203 },
1204 SnapEntry {
1205 key: "ttl".into(),
1206 value: SnapValue::String(Bytes::from("expiring")),
1207 expire_ms: 5000,
1208 },
1209 ];
1210
1211 {
1212 let mut writer = SnapshotWriter::create_encrypted(&path, 7, key.clone())?;
1213 for entry in &entries {
1214 writer.write_entry(entry)?;
1215 }
1216 writer.finish()?;
1217 }
1218
1219 let mut reader = SnapshotReader::open_encrypted(&path, key)?;
1220 assert_eq!(reader.shard_id, 7);
1221 assert_eq!(reader.entry_count, 2);
1222
1223 let mut got = Vec::new();
1224 while let Some(entry) = reader.read_entry()? {
1225 got.push(entry);
1226 }
1227 assert_eq!(entries, got);
1228 reader.verify_footer()?;
1229 Ok(())
1230 }
1231
1232 #[test]
1233 fn encrypted_snapshot_wrong_key_fails() -> Result {
1234 let dir = temp_dir();
1235 let path = dir.path().join("enc_bad.snap");
1236 let key = test_key();
1237 let wrong_key = EncryptionKey::from_bytes([0xFF; 32]);
1238
1239 {
1240 let mut writer = SnapshotWriter::create_encrypted(&path, 0, key)?;
1241 writer.write_entry(&SnapEntry {
1242 key: "k".into(),
1243 value: SnapValue::String(Bytes::from("v")),
1244 expire_ms: -1,
1245 })?;
1246 writer.finish()?;
1247 }
1248
1249 let mut reader = SnapshotReader::open_encrypted(&path, wrong_key)?;
1250 let err = reader.read_entry().unwrap_err();
1251 assert!(matches!(err, FormatError::DecryptionFailed));
1252 Ok(())
1253 }
1254
1255 #[test]
1256 fn v2_snapshot_readable_with_encryption_key() -> Result {
1257 let dir = temp_dir();
1258 let path = dir.path().join("v2.snap");
1259 let key = test_key();
1260
1261 {
1262 let mut writer = SnapshotWriter::create(&path, 0)?;
1263 writer.write_entry(&SnapEntry {
1264 key: "k".into(),
1265 value: SnapValue::String(Bytes::from("v")),
1266 expire_ms: -1,
1267 })?;
1268 writer.finish()?;
1269 }
1270
1271 let mut reader = SnapshotReader::open_encrypted(&path, key)?;
1272 let entry = reader.read_entry()?.unwrap();
1273 assert_eq!(entry.key, "k");
1274 reader.verify_footer()?;
1275 Ok(())
1276 }
1277
1278 #[test]
1279 fn v3_snapshot_without_key_returns_error() -> Result {
1280 let dir = temp_dir();
1281 let path = dir.path().join("v3_nokey.snap");
1282 let key = test_key();
1283
1284 {
1285 let mut writer = SnapshotWriter::create_encrypted(&path, 0, key)?;
1286 writer.write_entry(&SnapEntry {
1287 key: "k".into(),
1288 value: SnapValue::String(Bytes::from("v")),
1289 expire_ms: -1,
1290 })?;
1291 writer.finish()?;
1292 }
1293
1294 let result = SnapshotReader::open(&path);
1295 assert!(matches!(result, Err(FormatError::EncryptionRequired)));
1296 Ok(())
1297 }
1298
1299 #[test]
1300 fn encrypted_snapshot_with_all_types() -> Result {
1301 let dir = temp_dir();
1302 let path = dir.path().join("enc_types.snap");
1303 let key = test_key();
1304
1305 let mut deque = VecDeque::new();
1306 deque.push_back(Bytes::from("a"));
1307 deque.push_back(Bytes::from("b"));
1308
1309 let mut hash = HashMap::new();
1310 hash.insert("f1".into(), Bytes::from("v1"));
1311
1312 let mut set = HashSet::new();
1313 set.insert("m1".into());
1314 set.insert("m2".into());
1315
1316 let entries = vec![
1317 SnapEntry {
1318 key: "str".into(),
1319 value: SnapValue::String(Bytes::from("val")),
1320 expire_ms: -1,
1321 },
1322 SnapEntry {
1323 key: "list".into(),
1324 value: SnapValue::List(deque),
1325 expire_ms: 1000,
1326 },
1327 SnapEntry {
1328 key: "zset".into(),
1329 value: SnapValue::SortedSet(vec![(1.0, "a".into()), (2.0, "b".into())]),
1330 expire_ms: -1,
1331 },
1332 SnapEntry {
1333 key: "hash".into(),
1334 value: SnapValue::Hash(hash),
1335 expire_ms: -1,
1336 },
1337 SnapEntry {
1338 key: "set".into(),
1339 value: SnapValue::Set(set),
1340 expire_ms: -1,
1341 },
1342 ];
1343
1344 {
1345 let mut writer = SnapshotWriter::create_encrypted(&path, 0, key.clone())?;
1346 for entry in &entries {
1347 writer.write_entry(entry)?;
1348 }
1349 writer.finish()?;
1350 }
1351
1352 let mut reader = SnapshotReader::open_encrypted(&path, key)?;
1353 let mut got = Vec::new();
1354 while let Some(entry) = reader.read_entry()? {
1355 got.push(entry);
1356 }
1357 assert_eq!(entries, got);
1358 reader.verify_footer()?;
1359 Ok(())
1360 }
1361 }
1362}