1use std::collections::{HashMap, HashSet, VecDeque};
25use std::fs::{self, File, OpenOptions};
26use std::io::{self, BufReader, BufWriter, Write};
27use std::path::{Path, PathBuf};
28
29use bytes::Bytes;
30
31use crate::format::{self, FormatError};
32
// One-byte type tags written in front of every serialized value. The tag
// selects which `SnapValue` variant the following bytes decode to.

/// Tag for [`SnapValue::String`].
const TYPE_STRING: u8 = 0;
/// Tag for [`SnapValue::List`].
const TYPE_LIST: u8 = 1;
/// Tag for [`SnapValue::SortedSet`].
const TYPE_SORTED_SET: u8 = 2;
/// Tag for [`SnapValue::Hash`].
const TYPE_HASH: u8 = 3;
/// Tag for [`SnapValue::Set`].
const TYPE_SET: u8 = 4;
/// Tag for [`SnapValue::Vector`]; only compiled with the `vector` feature.
#[cfg(feature = "vector")]
const TYPE_VECTOR: u8 = 6;
/// Tag for [`SnapValue::Proto`]; only compiled with the `protobuf` feature.
#[cfg(feature = "protobuf")]
const TYPE_PROTO: u8 = 5;
43
44fn parse_utf8(bytes: Vec<u8>, field: &str) -> Result<String, FormatError> {
48 String::from_utf8(bytes).map_err(|_| {
49 FormatError::Io(io::Error::new(
50 io::ErrorKind::InvalidData,
51 format!("{field} is not valid utf-8"),
52 ))
53 })
54}
55
56fn read_snap_string(r: &mut impl io::Read, field: &str) -> Result<String, FormatError> {
58 let bytes = format::read_bytes(r)?;
59 parse_utf8(bytes, field)
60}
61
62fn parse_snap_value(r: &mut impl io::Read) -> Result<SnapValue, FormatError> {
68 let type_tag = format::read_u8(r)?;
69 match type_tag {
70 TYPE_STRING => {
71 let v = format::read_bytes(r)?;
72 Ok(SnapValue::String(Bytes::from(v)))
73 }
74 TYPE_LIST => {
75 let count = format::read_u32(r)?;
76 format::validate_collection_count(count, "list")?;
77 let mut deque = VecDeque::with_capacity(format::capped_capacity(count));
78 for _ in 0..count {
79 deque.push_back(Bytes::from(format::read_bytes(r)?));
80 }
81 Ok(SnapValue::List(deque))
82 }
83 TYPE_SORTED_SET => {
84 let count = format::read_u32(r)?;
85 format::validate_collection_count(count, "sorted set")?;
86 let mut members = Vec::with_capacity(format::capped_capacity(count));
87 for _ in 0..count {
88 let score = format::read_f64(r)?;
89 let member = read_snap_string(r, "member")?;
90 members.push((score, member));
91 }
92 Ok(SnapValue::SortedSet(members))
93 }
94 TYPE_HASH => {
95 let count = format::read_u32(r)?;
96 format::validate_collection_count(count, "hash")?;
97 let mut map = HashMap::with_capacity(format::capped_capacity(count));
98 for _ in 0..count {
99 let field = read_snap_string(r, "hash field")?;
100 let value = format::read_bytes(r)?;
101 map.insert(field, Bytes::from(value));
102 }
103 Ok(SnapValue::Hash(map))
104 }
105 TYPE_SET => {
106 let count = format::read_u32(r)?;
107 format::validate_collection_count(count, "set")?;
108 let mut set = HashSet::with_capacity(format::capped_capacity(count));
109 for _ in 0..count {
110 let member = read_snap_string(r, "set member")?;
111 set.insert(member);
112 }
113 Ok(SnapValue::Set(set))
114 }
115 #[cfg(feature = "vector")]
116 TYPE_VECTOR => {
117 let metric = format::read_u8(r)?;
118 if metric > 2 {
119 return Err(FormatError::InvalidData(format!(
120 "unknown vector metric: {metric}"
121 )));
122 }
123 let quantization = format::read_u8(r)?;
124 if quantization > 2 {
125 return Err(FormatError::InvalidData(format!(
126 "unknown vector quantization: {quantization}"
127 )));
128 }
129 let connectivity = format::read_u32(r)?;
130 let expansion_add = format::read_u32(r)?;
131 let dim = format::read_u32(r)?;
132 if dim > format::MAX_PERSISTED_VECTOR_DIMS {
133 return Err(FormatError::InvalidData(format!(
134 "vector dimension {dim} exceeds max {}",
135 format::MAX_PERSISTED_VECTOR_DIMS
136 )));
137 }
138 let count = format::read_u32(r)?;
139 if count > format::MAX_PERSISTED_VECTOR_COUNT {
140 return Err(FormatError::InvalidData(format!(
141 "vector element count {count} exceeds max {}",
142 format::MAX_PERSISTED_VECTOR_COUNT
143 )));
144 }
145 format::validate_vector_total(dim, count)?;
146 let mut elements = Vec::with_capacity(format::capped_capacity(count));
147 for _ in 0..count {
148 let name = read_snap_string(r, "vector element name")?;
149 let mut vector = Vec::with_capacity(dim as usize);
150 for _ in 0..dim {
151 vector.push(format::read_f32(r)?);
152 }
153 elements.push((name, vector));
154 }
155 Ok(SnapValue::Vector {
156 metric,
157 quantization,
158 connectivity,
159 expansion_add,
160 dim,
161 elements,
162 })
163 }
164 #[cfg(feature = "protobuf")]
165 TYPE_PROTO => {
166 let type_name = read_snap_string(r, "proto type_name")?;
167 let data = format::read_bytes(r)?;
168 Ok(SnapValue::Proto {
169 type_name,
170 data: Bytes::from(data),
171 })
172 }
173 _ => Err(FormatError::UnknownTag(type_tag)),
174 }
175}
176
/// A value stored in a snapshot entry.
///
/// Each variant corresponds to one of the `TYPE_*` tags used on disk.
#[derive(Debug, Clone, PartialEq)]
pub enum SnapValue {
    /// Raw bytes.
    String(Bytes),
    /// Ordered sequence of byte strings, serialized front to back.
    List(VecDeque<Bytes>),
    /// `(score, member)` pairs, serialized in the order they appear in the
    /// vector.
    SortedSet(Vec<(f64, String)>),
    /// Field name to value mapping.
    Hash(HashMap<String, Bytes>),
    /// Set of UTF-8 members.
    Set(HashSet<String>),
    /// Vector collection; only available with the `vector` feature.
    #[cfg(feature = "vector")]
    Vector {
        /// Metric code; the parsers in this file reject values above 2.
        metric: u8,
        /// Quantization code; the parsers in this file reject values above 2.
        quantization: u8,
        /// Stored and restored verbatim by this module.
        connectivity: u32,
        /// Stored and restored verbatim by this module.
        expansion_add: u32,
        /// Number of `f32` components the readers expect for every element.
        dim: u32,
        /// `(name, components)` pairs.
        elements: Vec<(String, Vec<f32>)>,
    },
    /// Typed binary payload; only available with the `protobuf` feature.
    #[cfg(feature = "protobuf")]
    Proto { type_name: String, data: Bytes },
}
204
/// One key/value record of a snapshot.
#[derive(Debug, Clone, PartialEq)]
pub struct SnapEntry {
    /// Entry key; serialized as a UTF-8 byte field.
    pub key: String,
    /// The typed value stored under `key`.
    pub value: SnapValue,
    /// Expiry in milliseconds, serialized as an `i64` after the value.
    // NOTE(review): the tests in this file use -1 for entries without a TTL,
    // presumably meaning "no expiry" — confirm against the callers.
    pub expire_ms: i64,
}
213
214impl SnapEntry {
215 fn estimated_size(&self) -> usize {
217 const LEN_PREFIX: usize = 4;
218
219 let key_size = LEN_PREFIX + self.key.len();
220 let value_size = match &self.value {
221 SnapValue::String(data) => 1 + LEN_PREFIX + data.len(),
222 SnapValue::List(deque) => {
223 let items: usize = deque.iter().map(|v| LEN_PREFIX + v.len()).sum();
224 1 + 4 + items
225 }
226 SnapValue::SortedSet(members) => {
227 let items: usize = members.iter().map(|(_, m)| 8 + LEN_PREFIX + m.len()).sum();
228 1 + 4 + items
229 }
230 SnapValue::Hash(map) => {
231 let items: usize = map
232 .iter()
233 .map(|(f, v)| LEN_PREFIX + f.len() + LEN_PREFIX + v.len())
234 .sum();
235 1 + 4 + items
236 }
237 SnapValue::Set(set) => {
238 let items: usize = set.iter().map(|m| LEN_PREFIX + m.len()).sum();
239 1 + 4 + items
240 }
241 #[cfg(feature = "vector")]
242 SnapValue::Vector { dim, elements, .. } => {
243 let items: usize = elements
244 .iter()
245 .map(|(name, _)| LEN_PREFIX + name.len() + (*dim as usize) * 4)
246 .sum();
247 1 + 2 + 4 + 4 + 4 + 4 + items
249 }
250 #[cfg(feature = "protobuf")]
251 SnapValue::Proto { type_name, data } => {
252 1 + LEN_PREFIX + type_name.len() + LEN_PREFIX + data.len()
253 }
254 };
255 key_size + value_size + 8
257 }
258}
259
260pub fn serialize_snap_value(value: &SnapValue) -> Result<Vec<u8>, FormatError> {
266 let mut buf = Vec::new();
267 match value {
268 SnapValue::String(data) => {
269 format::write_u8(&mut buf, TYPE_STRING)?;
270 format::write_bytes(&mut buf, data)?;
271 }
272 SnapValue::List(deque) => {
273 format::write_u8(&mut buf, TYPE_LIST)?;
274 format::write_len(&mut buf, deque.len())?;
275 for item in deque {
276 format::write_bytes(&mut buf, item)?;
277 }
278 }
279 SnapValue::SortedSet(members) => {
280 format::write_u8(&mut buf, TYPE_SORTED_SET)?;
281 format::write_len(&mut buf, members.len())?;
282 for (score, member) in members {
283 format::write_f64(&mut buf, *score)?;
284 format::write_bytes(&mut buf, member.as_bytes())?;
285 }
286 }
287 SnapValue::Hash(map) => {
288 format::write_u8(&mut buf, TYPE_HASH)?;
289 format::write_len(&mut buf, map.len())?;
290 for (field, value) in map {
291 format::write_bytes(&mut buf, field.as_bytes())?;
292 format::write_bytes(&mut buf, value)?;
293 }
294 }
295 SnapValue::Set(set) => {
296 format::write_u8(&mut buf, TYPE_SET)?;
297 format::write_len(&mut buf, set.len())?;
298 for member in set {
299 format::write_bytes(&mut buf, member.as_bytes())?;
300 }
301 }
302 #[cfg(feature = "vector")]
303 SnapValue::Vector {
304 metric,
305 quantization,
306 connectivity,
307 expansion_add,
308 dim,
309 elements,
310 } => {
311 format::write_u8(&mut buf, TYPE_VECTOR)?;
312 format::write_u8(&mut buf, *metric)?;
313 format::write_u8(&mut buf, *quantization)?;
314 format::write_u32(&mut buf, *connectivity)?;
315 format::write_u32(&mut buf, *expansion_add)?;
316 format::write_u32(&mut buf, *dim)?;
317 format::write_len(&mut buf, elements.len())?;
318 for (name, vector) in elements {
319 format::write_bytes(&mut buf, name.as_bytes())?;
320 for &v in vector {
321 format::write_f32(&mut buf, v)?;
322 }
323 }
324 }
325 #[cfg(feature = "protobuf")]
326 SnapValue::Proto { type_name, data } => {
327 format::write_u8(&mut buf, TYPE_PROTO)?;
328 format::write_bytes(&mut buf, type_name.as_bytes())?;
329 format::write_bytes(&mut buf, data)?;
330 }
331 }
332 Ok(buf)
333}
334
335pub fn deserialize_snap_value(data: &[u8]) -> Result<SnapValue, FormatError> {
337 let mut cursor = io::Cursor::new(data);
338 parse_snap_value(&mut cursor)
339}
340
/// Streams entries into a snapshot file.
///
/// Data is written to a temporary file next to the destination and renamed
/// over the destination by `finish`. If the writer is dropped without a
/// successful `finish`, the `Drop` impl removes the temporary file.
pub struct SnapshotWriter {
    /// Destination the temporary file is renamed to by `finish`.
    final_path: PathBuf,
    /// Temporary file that receives all writes until `finish`.
    tmp_path: PathBuf,
    /// Buffered handle on the temporary file.
    writer: BufWriter<File>,
    /// Running CRC32 over every record written by `write_entry`.
    hasher: crc32fast::Hasher,
    /// Number of entries written so far; patched into the header by `finish`.
    count: u32,
    /// Set once `finish` has completed; stops `Drop` from deleting the file.
    finished: bool,
    /// When `Some`, every record is encrypted before being written.
    #[cfg(feature = "encryption")]
    encryption_key: Option<crate::encryption::EncryptionKey>,
}
359
360impl SnapshotWriter {
361 pub fn create(path: impl Into<PathBuf>, shard_id: u16) -> Result<Self, FormatError> {
364 let final_path = path.into();
365 let (tmp_path, writer) = Self::open_tmp(&final_path)?;
366 let mut writer = BufWriter::new(writer);
367
368 format::write_header(&mut writer, format::SNAP_MAGIC)?;
369 format::write_u16(&mut writer, shard_id)?;
370 format::write_u32(&mut writer, 0)?;
371
372 Ok(Self {
373 final_path,
374 tmp_path,
375 writer,
376 hasher: crc32fast::Hasher::new(),
377 count: 0,
378 finished: false,
379 #[cfg(feature = "encryption")]
380 encryption_key: None,
381 })
382 }
383
384 #[cfg(feature = "encryption")]
386 pub fn create_encrypted(
387 path: impl Into<PathBuf>,
388 shard_id: u16,
389 key: crate::encryption::EncryptionKey,
390 ) -> Result<Self, FormatError> {
391 let final_path = path.into();
392 let (tmp_path, file) = Self::open_tmp(&final_path)?;
393 let mut writer = BufWriter::new(file);
394
395 format::write_header_versioned(
396 &mut writer,
397 format::SNAP_MAGIC,
398 format::FORMAT_VERSION_ENCRYPTED,
399 )?;
400 format::write_u16(&mut writer, shard_id)?;
401 format::write_u32(&mut writer, 0)?;
402
403 Ok(Self {
404 final_path,
405 tmp_path,
406 writer,
407 hasher: crc32fast::Hasher::new(),
408 count: 0,
409 finished: false,
410 encryption_key: Some(key),
411 })
412 }
413
414 fn open_tmp(final_path: &Path) -> Result<(PathBuf, File), FormatError> {
416 let tmp_path = final_path.with_extension("snap.tmp");
417 let mut opts = OpenOptions::new();
418 opts.write(true).create(true).truncate(true);
419 #[cfg(unix)]
420 {
421 use std::os::unix::fs::OpenOptionsExt;
422 opts.mode(0o600);
423 }
424 let file = opts.open(&tmp_path)?;
425 Ok((tmp_path, file))
426 }
427
428 pub fn write_entry(&mut self, entry: &SnapEntry) -> Result<(), FormatError> {
433 let mut buf = Vec::with_capacity(entry.estimated_size());
434 format::write_bytes(&mut buf, entry.key.as_bytes())?;
435 match &entry.value {
436 SnapValue::String(data) => {
437 format::write_u8(&mut buf, TYPE_STRING)?;
438 format::write_bytes(&mut buf, data)?;
439 }
440 SnapValue::List(deque) => {
441 format::write_u8(&mut buf, TYPE_LIST)?;
442 format::write_len(&mut buf, deque.len())?;
443 for item in deque {
444 format::write_bytes(&mut buf, item)?;
445 }
446 }
447 SnapValue::SortedSet(members) => {
448 format::write_u8(&mut buf, TYPE_SORTED_SET)?;
449 format::write_len(&mut buf, members.len())?;
450 for (score, member) in members {
451 format::write_f64(&mut buf, *score)?;
452 format::write_bytes(&mut buf, member.as_bytes())?;
453 }
454 }
455 SnapValue::Hash(map) => {
456 format::write_u8(&mut buf, TYPE_HASH)?;
457 format::write_len(&mut buf, map.len())?;
458 for (field, value) in map {
459 format::write_bytes(&mut buf, field.as_bytes())?;
460 format::write_bytes(&mut buf, value)?;
461 }
462 }
463 SnapValue::Set(set) => {
464 format::write_u8(&mut buf, TYPE_SET)?;
465 format::write_len(&mut buf, set.len())?;
466 for member in set {
467 format::write_bytes(&mut buf, member.as_bytes())?;
468 }
469 }
470 #[cfg(feature = "vector")]
471 SnapValue::Vector {
472 metric,
473 quantization,
474 connectivity,
475 expansion_add,
476 dim,
477 elements,
478 } => {
479 format::write_u8(&mut buf, TYPE_VECTOR)?;
480 format::write_u8(&mut buf, *metric)?;
481 format::write_u8(&mut buf, *quantization)?;
482 format::write_u32(&mut buf, *connectivity)?;
483 format::write_u32(&mut buf, *expansion_add)?;
484 format::write_u32(&mut buf, *dim)?;
485 format::write_len(&mut buf, elements.len())?;
486 for (name, vector) in elements {
487 format::write_bytes(&mut buf, name.as_bytes())?;
488 for &v in vector {
489 format::write_f32(&mut buf, v)?;
490 }
491 }
492 }
493 #[cfg(feature = "protobuf")]
494 SnapValue::Proto { type_name, data } => {
495 format::write_u8(&mut buf, TYPE_PROTO)?;
496 format::write_bytes(&mut buf, type_name.as_bytes())?;
497 format::write_bytes(&mut buf, data)?;
498 }
499 }
500 format::write_i64(&mut buf, entry.expire_ms)?;
501
502 #[cfg(feature = "encryption")]
503 if let Some(ref key) = self.encryption_key {
504 let (nonce, ciphertext) = crate::encryption::encrypt_record(key, &buf)?;
505 let ct_len = u32::try_from(ciphertext.len()).map_err(|_| {
506 io::Error::new(
507 io::ErrorKind::InvalidInput,
508 "encrypted record exceeds u32::MAX bytes",
509 )
510 })?;
511 self.hasher.update(&nonce);
513 let ct_len_bytes = ct_len.to_le_bytes();
514 self.hasher.update(&ct_len_bytes);
515 self.hasher.update(&ciphertext);
516 self.writer.write_all(&nonce)?;
517 format::write_u32(&mut self.writer, ct_len)?;
518 self.writer.write_all(&ciphertext)?;
519 self.count += 1;
520 return Ok(());
521 }
522
523 self.hasher.update(&buf);
524 self.writer.write_all(&buf)?;
525 self.count += 1;
526 Ok(())
527 }
528
    /// Completes the snapshot and moves it to its final path.
    ///
    /// Steps, in order:
    /// 1. append the CRC32 of all written records as the footer,
    /// 2. flush the buffer and fsync the temporary file,
    /// 3. overwrite the placeholder entry count in the header and fsync again,
    /// 4. rename the temporary file over the final path,
    /// 5. best-effort fsync of the parent directory.
    ///
    /// # Errors
    ///
    /// Returns the first I/O error from steps 1-4. Because `finished` is only
    /// set at the very end, an early return lets `Drop` remove the temporary
    /// file. Failures in step 5 are ignored.
    pub fn finish(mut self) -> Result<(), FormatError> {
        // Finalize a clone rather than the field itself.
        // NOTE(review): presumably because `finalize` takes the hasher by
        // value and fields cannot be moved out of a type that implements
        // `Drop` — confirm.
        let checksum = self.hasher.clone().finalize();
        format::write_u32(&mut self.writer, checksum)?;
        self.writer.flush()?;
        self.writer.get_ref().sync_all()?;

        {
            // Patch the real entry count through a second handle on the
            // temporary file.
            use std::io::{Seek, SeekFrom};
            let mut file = fs::OpenOptions::new().write(true).open(&self.tmp_path)?;
            // Offset 7 must be where `create` wrote the placeholder count,
            // i.e. right after the header and the 2-byte shard id.
            // NOTE(review): this assumes `format::write_header` emits 5 bytes
            // — confirm against the format module.
            file.seek(SeekFrom::Start(7))?;
            format::write_u32(&mut file, self.count)?;
            file.sync_all()?;
        }

        fs::rename(&self.tmp_path, &self.final_path)?;

        // Best effort: sync the directory that now holds the renamed file;
        // errors here are deliberately ignored.
        if let Some(parent) = self.final_path.parent() {
            if let Ok(dir) = File::open(parent) {
                let _ = dir.sync_all();
            }
        }

        // From here on `Drop` must not try to delete the temporary file.
        self.finished = true;
        Ok(())
    }
563}
564
565impl Drop for SnapshotWriter {
566 fn drop(&mut self) {
567 if !self.finished {
568 let _ = fs::remove_file(&self.tmp_path);
570 }
571 }
572}
573
/// Reads entries back from a snapshot file, accumulating a CRC32 that
/// `verify_footer` compares against the stored footer.
pub struct SnapshotReader {
    /// Buffered handle on the snapshot file, positioned after the header
    /// once the reader has been opened.
    reader: BufReader<File>,
    /// Shard id read from the header.
    pub shard_id: u16,
    /// Entry count read from the header.
    pub entry_count: u32,
    /// Number of entries returned so far by `read_entry`.
    read_so_far: u32,
    /// Running CRC32 over the record bytes consumed so far.
    hasher: crc32fast::Hasher,
    /// Format version returned by `format::read_header`.
    version: u8,
    /// Key used to decrypt records of encrypted snapshots, if one was given.
    #[cfg(feature = "encryption")]
    encryption_key: Option<crate::encryption::EncryptionKey>,
}
586
587impl SnapshotReader {
588 pub fn open(path: impl AsRef<Path>) -> Result<Self, FormatError> {
590 let file = File::open(path.as_ref())?;
591 let mut reader = BufReader::new(file);
592
593 let version = format::read_header(&mut reader, format::SNAP_MAGIC)?;
594
595 if version == format::FORMAT_VERSION_ENCRYPTED {
596 return Err(FormatError::EncryptionRequired);
597 }
598
599 let shard_id = format::read_u16(&mut reader)?;
600 let entry_count = format::read_u32(&mut reader)?;
601
602 Ok(Self {
603 reader,
604 shard_id,
605 entry_count,
606 read_so_far: 0,
607 hasher: crc32fast::Hasher::new(),
608 version,
609 #[cfg(feature = "encryption")]
610 encryption_key: None,
611 })
612 }
613
614 #[cfg(feature = "encryption")]
618 pub fn open_encrypted(
619 path: impl AsRef<Path>,
620 key: crate::encryption::EncryptionKey,
621 ) -> Result<Self, FormatError> {
622 let file = File::open(path.as_ref())?;
623 let mut reader = BufReader::new(file);
624
625 let version = format::read_header(&mut reader, format::SNAP_MAGIC)?;
626 let shard_id = format::read_u16(&mut reader)?;
627 let entry_count = format::read_u32(&mut reader)?;
628
629 Ok(Self {
630 reader,
631 shard_id,
632 entry_count,
633 read_so_far: 0,
634 hasher: crc32fast::Hasher::new(),
635 version,
636 encryption_key: Some(key),
637 })
638 }
639
640 pub fn read_entry(&mut self) -> Result<Option<SnapEntry>, FormatError> {
642 if self.read_so_far >= self.entry_count {
643 return Ok(None);
644 }
645
646 #[cfg(feature = "encryption")]
647 if self.version == format::FORMAT_VERSION_ENCRYPTED {
648 return self.read_encrypted_entry();
649 }
650
651 self.read_plaintext_entry()
652 }
653
    /// Reads one unencrypted record and feeds its bytes into the running CRC.
    ///
    /// Every field read from the file is immediately re-encoded into `buf`
    /// with the matching `format::write_*` helper; `buf` is what gets hashed
    /// once the whole record has been read. The order of reads and echo
    /// writes therefore has to mirror the on-disk layout exactly.
    ///
    /// Version 1 files carry no type tag: the value is a bare byte field and
    /// is always returned as [`SnapValue::String`].
    ///
    /// # Errors
    ///
    /// * [`FormatError::UnknownTag`] for an unrecognized type tag.
    /// * [`FormatError::InvalidData`] for out-of-range vector header fields.
    /// * Any read, count-validation or UTF-8 error from the helpers used.
    ///
    /// On any error the hasher and `read_so_far` are left untouched.
    fn read_plaintext_entry(&mut self) -> Result<Option<SnapEntry>, FormatError> {
        // Re-serialized copy of the record, used as CRC input.
        let mut buf = Vec::new();

        let key_bytes = format::read_bytes(&mut self.reader)?;
        format::write_bytes(&mut buf, &key_bytes)?;

        let value = if self.version == 1 {
            // Version 1 layout: the value is an untagged byte field.
            let value_bytes = format::read_bytes(&mut self.reader)?;
            format::write_bytes(&mut buf, &value_bytes)?;
            SnapValue::String(Bytes::from(value_bytes))
        } else {
            let type_tag = format::read_u8(&mut self.reader)?;
            format::write_u8(&mut buf, type_tag)?;
            match type_tag {
                TYPE_STRING => {
                    let value_bytes = format::read_bytes(&mut self.reader)?;
                    format::write_bytes(&mut buf, &value_bytes)?;
                    SnapValue::String(Bytes::from(value_bytes))
                }
                TYPE_LIST => {
                    let count = format::read_u32(&mut self.reader)?;
                    format::validate_collection_count(count, "list")?;
                    format::write_u32(&mut buf, count)?;
                    let mut deque = VecDeque::with_capacity(format::capped_capacity(count));
                    for _ in 0..count {
                        let item = format::read_bytes(&mut self.reader)?;
                        format::write_bytes(&mut buf, &item)?;
                        deque.push_back(Bytes::from(item));
                    }
                    SnapValue::List(deque)
                }
                TYPE_SORTED_SET => {
                    let count = format::read_u32(&mut self.reader)?;
                    format::validate_collection_count(count, "sorted set")?;
                    format::write_u32(&mut buf, count)?;
                    let mut members = Vec::with_capacity(format::capped_capacity(count));
                    for _ in 0..count {
                        let score = format::read_f64(&mut self.reader)?;
                        format::write_f64(&mut buf, score)?;
                        let member_bytes = format::read_bytes(&mut self.reader)?;
                        // Echo the raw bytes before UTF-8 validation consumes them.
                        format::write_bytes(&mut buf, &member_bytes)?;
                        let member = parse_utf8(member_bytes, "member")?;
                        members.push((score, member));
                    }
                    SnapValue::SortedSet(members)
                }
                TYPE_HASH => {
                    let count = format::read_u32(&mut self.reader)?;
                    format::validate_collection_count(count, "hash")?;
                    format::write_u32(&mut buf, count)?;
                    let mut map = HashMap::with_capacity(format::capped_capacity(count));
                    for _ in 0..count {
                        let field_bytes = format::read_bytes(&mut self.reader)?;
                        format::write_bytes(&mut buf, &field_bytes)?;
                        let field = parse_utf8(field_bytes, "hash field")?;
                        let value_bytes = format::read_bytes(&mut self.reader)?;
                        format::write_bytes(&mut buf, &value_bytes)?;
                        map.insert(field, Bytes::from(value_bytes));
                    }
                    SnapValue::Hash(map)
                }
                TYPE_SET => {
                    let count = format::read_u32(&mut self.reader)?;
                    format::validate_collection_count(count, "set")?;
                    format::write_u32(&mut buf, count)?;
                    let mut set = HashSet::with_capacity(format::capped_capacity(count));
                    for _ in 0..count {
                        let member_bytes = format::read_bytes(&mut self.reader)?;
                        format::write_bytes(&mut buf, &member_bytes)?;
                        let member = parse_utf8(member_bytes, "set member")?;
                        set.insert(member);
                    }
                    SnapValue::Set(set)
                }
                #[cfg(feature = "vector")]
                TYPE_VECTOR => {
                    // Each header field is validated before it is echoed, so
                    // a rejected record never reaches the hasher.
                    let metric = format::read_u8(&mut self.reader)?;
                    if metric > 2 {
                        return Err(FormatError::InvalidData(format!(
                            "unknown vector metric: {metric}"
                        )));
                    }
                    format::write_u8(&mut buf, metric)?;
                    let quantization = format::read_u8(&mut self.reader)?;
                    if quantization > 2 {
                        return Err(FormatError::InvalidData(format!(
                            "unknown vector quantization: {quantization}"
                        )));
                    }
                    format::write_u8(&mut buf, quantization)?;
                    let connectivity = format::read_u32(&mut self.reader)?;
                    format::write_u32(&mut buf, connectivity)?;
                    let expansion_add = format::read_u32(&mut self.reader)?;
                    format::write_u32(&mut buf, expansion_add)?;
                    let dim = format::read_u32(&mut self.reader)?;
                    if dim > format::MAX_PERSISTED_VECTOR_DIMS {
                        return Err(FormatError::InvalidData(format!(
                            "vector dimension {dim} exceeds max {}",
                            format::MAX_PERSISTED_VECTOR_DIMS
                        )));
                    }
                    format::write_u32(&mut buf, dim)?;
                    let count = format::read_u32(&mut self.reader)?;
                    if count > format::MAX_PERSISTED_VECTOR_COUNT {
                        return Err(FormatError::InvalidData(format!(
                            "vector element count {count} exceeds max {}",
                            format::MAX_PERSISTED_VECTOR_COUNT
                        )));
                    }
                    format::validate_vector_total(dim, count)?;
                    format::write_u32(&mut buf, count)?;
                    let mut elements = Vec::with_capacity(format::capped_capacity(count));
                    for _ in 0..count {
                        let name_bytes = format::read_bytes(&mut self.reader)?;
                        format::write_bytes(&mut buf, &name_bytes)?;
                        let name = parse_utf8(name_bytes, "vector element name")?;
                        // `dim` was bounded above, so this allocation is bounded.
                        let mut vector = Vec::with_capacity(dim as usize);
                        for _ in 0..dim {
                            let v = format::read_f32(&mut self.reader)?;
                            format::write_f32(&mut buf, v)?;
                            vector.push(v);
                        }
                        elements.push((name, vector));
                    }
                    SnapValue::Vector {
                        metric,
                        quantization,
                        connectivity,
                        expansion_add,
                        dim,
                        elements,
                    }
                }
                #[cfg(feature = "protobuf")]
                TYPE_PROTO => {
                    let type_name_bytes = format::read_bytes(&mut self.reader)?;
                    format::write_bytes(&mut buf, &type_name_bytes)?;
                    let type_name = parse_utf8(type_name_bytes, "proto type_name")?;
                    let data = format::read_bytes(&mut self.reader)?;
                    format::write_bytes(&mut buf, &data)?;
                    SnapValue::Proto {
                        type_name,
                        data: Bytes::from(data),
                    }
                }
                _ => {
                    return Err(FormatError::UnknownTag(type_tag));
                }
            }
        };

        let expire_ms = format::read_i64(&mut self.reader)?;
        format::write_i64(&mut buf, expire_ms)?;
        // The record is complete: fold its bytes into the running checksum.
        self.hasher.update(&buf);

        // The key is validated only after hashing, so a record with a
        // non-UTF-8 key has already been folded into the CRC when this fails.
        let key = parse_utf8(key_bytes, "key")?;

        self.read_so_far += 1;
        Ok(Some(SnapEntry {
            key,
            value,
            expire_ms,
        }))
    }
821
822 #[cfg(feature = "encryption")]
825 fn read_encrypted_entry(&mut self) -> Result<Option<SnapEntry>, FormatError> {
826 use std::io::Read as _;
827
828 let key = self
829 .encryption_key
830 .as_ref()
831 .ok_or(FormatError::EncryptionRequired)?;
832
833 let mut nonce = [0u8; crate::encryption::NONCE_SIZE];
834 self.reader
835 .read_exact(&mut nonce)
836 .map_err(|e| match e.kind() {
837 io::ErrorKind::UnexpectedEof => FormatError::UnexpectedEof,
838 _ => FormatError::Io(e),
839 })?;
840
841 let ct_len = format::read_u32(&mut self.reader)? as usize;
842 if ct_len > format::MAX_FIELD_LEN {
843 return Err(FormatError::Io(io::Error::new(
844 io::ErrorKind::InvalidData,
845 format!("encrypted entry length {ct_len} exceeds maximum"),
846 )));
847 }
848
849 let mut ciphertext = vec![0u8; ct_len];
850 self.reader
851 .read_exact(&mut ciphertext)
852 .map_err(|e| match e.kind() {
853 io::ErrorKind::UnexpectedEof => FormatError::UnexpectedEof,
854 _ => FormatError::Io(e),
855 })?;
856
857 self.hasher.update(&nonce);
859 let ct_len_bytes = (ct_len as u32).to_le_bytes();
860 self.hasher.update(&ct_len_bytes);
861 self.hasher.update(&ciphertext);
862
863 let plaintext = crate::encryption::decrypt_record(key, &nonce, &ciphertext)?;
864
865 let mut cursor = io::Cursor::new(&plaintext);
866 let entry_key = read_snap_string(&mut cursor, "key")?;
867 let value = parse_snap_value(&mut cursor)?;
868 let expire_ms = format::read_i64(&mut cursor)?;
869
870 self.read_so_far += 1;
871 Ok(Some(SnapEntry {
872 key: entry_key,
873 value,
874 expire_ms,
875 }))
876 }
877
878 pub fn verify_footer(self) -> Result<(), FormatError> {
881 let expected = self.hasher.finalize();
882 let mut reader = self.reader;
883 let stored = format::read_u32(&mut reader)?;
884 format::verify_crc32_values(expected, stored)
885 }
886}
887
888pub fn write_snapshot_bytes(shard_id: u16, entries: &[SnapEntry]) -> Result<Vec<u8>, FormatError> {
895 use std::io::{Seek, SeekFrom, Write as _};
896
897 let mut buf = io::Cursor::new(Vec::<u8>::new());
898 let mut hasher = crc32fast::Hasher::new();
899
900 format::write_header(&mut buf, format::SNAP_MAGIC)?;
901 format::write_u16(&mut buf, shard_id)?;
902 let count_pos = buf.position();
904 format::write_u32(&mut buf, 0u32)?;
905
906 let mut count = 0u32;
907 for entry in entries {
908 let entry_bytes = serialize_entry(entry)?;
909 hasher.update(&entry_bytes);
910 buf.write_all(&entry_bytes)?;
911 count += 1;
912 }
913
914 let end_pos = buf.position();
916 buf.seek(SeekFrom::Start(count_pos))?;
917 format::write_u32(&mut buf, count)?;
918 buf.seek(SeekFrom::Start(end_pos))?;
919
920 let checksum = hasher.finalize();
922 format::write_u32(&mut buf, checksum)?;
923
924 Ok(buf.into_inner())
925}
926
927pub fn read_snapshot_from_bytes(data: &[u8]) -> Result<(u16, Vec<SnapEntry>), FormatError> {
932 let mut r = io::Cursor::new(data);
933 let mut hasher = crc32fast::Hasher::new();
934
935 let version = format::read_header(&mut r, format::SNAP_MAGIC)?;
936 if version != format::FORMAT_VERSION {
937 return Err(FormatError::UnsupportedVersion(version));
938 }
939 let shard_id = format::read_u16(&mut r)?;
940 let entry_count = format::read_u32(&mut r)?;
941
942 let mut entries = Vec::with_capacity(entry_count.min(65536) as usize);
943 for _ in 0..entry_count {
944 let (entry, entry_bytes) = read_entry_with_bytes(&mut r)?;
945 hasher.update(&entry_bytes);
946 entries.push(entry);
947 }
948
949 let expected = hasher.finalize();
951 let stored = format::read_u32(&mut r)?;
952 format::verify_crc32_values(expected, stored)?;
953
954 Ok((shard_id, entries))
955}
956
957fn serialize_entry(entry: &SnapEntry) -> Result<Vec<u8>, FormatError> {
961 let mut buf = Vec::with_capacity(entry.estimated_size());
962 format::write_bytes(&mut buf, entry.key.as_bytes())?;
963 match &entry.value {
964 SnapValue::String(data) => {
965 format::write_u8(&mut buf, TYPE_STRING)?;
966 format::write_bytes(&mut buf, data)?;
967 }
968 SnapValue::List(deque) => {
969 format::write_u8(&mut buf, TYPE_LIST)?;
970 format::write_len(&mut buf, deque.len())?;
971 for item in deque {
972 format::write_bytes(&mut buf, item)?;
973 }
974 }
975 SnapValue::SortedSet(members) => {
976 format::write_u8(&mut buf, TYPE_SORTED_SET)?;
977 format::write_len(&mut buf, members.len())?;
978 for (score, member) in members {
979 format::write_f64(&mut buf, *score)?;
980 format::write_bytes(&mut buf, member.as_bytes())?;
981 }
982 }
983 SnapValue::Hash(map) => {
984 format::write_u8(&mut buf, TYPE_HASH)?;
985 format::write_len(&mut buf, map.len())?;
986 for (field, value) in map {
987 format::write_bytes(&mut buf, field.as_bytes())?;
988 format::write_bytes(&mut buf, value)?;
989 }
990 }
991 SnapValue::Set(set) => {
992 format::write_u8(&mut buf, TYPE_SET)?;
993 format::write_len(&mut buf, set.len())?;
994 for member in set {
995 format::write_bytes(&mut buf, member.as_bytes())?;
996 }
997 }
998 #[cfg(feature = "vector")]
999 SnapValue::Vector {
1000 metric,
1001 quantization,
1002 connectivity,
1003 expansion_add,
1004 dim,
1005 elements,
1006 } => {
1007 format::write_u8(&mut buf, TYPE_VECTOR)?;
1008 format::write_u8(&mut buf, *metric)?;
1009 format::write_u8(&mut buf, *quantization)?;
1010 format::write_u32(&mut buf, *connectivity)?;
1011 format::write_u32(&mut buf, *expansion_add)?;
1012 format::write_u32(&mut buf, *dim)?;
1013 format::write_len(&mut buf, elements.len())?;
1014 for (name, vector) in elements {
1015 format::write_bytes(&mut buf, name.as_bytes())?;
1016 for &v in vector {
1017 format::write_f32(&mut buf, v)?;
1018 }
1019 }
1020 }
1021 #[cfg(feature = "protobuf")]
1022 SnapValue::Proto { type_name, data } => {
1023 format::write_u8(&mut buf, TYPE_PROTO)?;
1024 format::write_bytes(&mut buf, type_name.as_bytes())?;
1025 format::write_bytes(&mut buf, data)?;
1026 }
1027 }
1028 format::write_i64(&mut buf, entry.expire_ms)?;
1029 Ok(buf)
1030}
1031
1032fn read_entry_with_bytes(r: &mut io::Cursor<&[u8]>) -> Result<(SnapEntry, Vec<u8>), FormatError> {
1035 let mut entry_bytes = Vec::new();
1036
1037 let key_bytes = format::read_bytes(r)?;
1038 format::write_bytes(&mut entry_bytes, &key_bytes)?;
1039 let key = parse_utf8(key_bytes, "key")?;
1040
1041 let type_tag = format::read_u8(r)?;
1042 format::write_u8(&mut entry_bytes, type_tag)?;
1043
1044 let value = match type_tag {
1045 TYPE_STRING => {
1046 let v = format::read_bytes(r)?;
1047 format::write_bytes(&mut entry_bytes, &v)?;
1048 SnapValue::String(Bytes::from(v))
1049 }
1050 TYPE_LIST => {
1051 let count = format::read_u32(r)?;
1052 format::validate_collection_count(count, "list")?;
1053 format::write_u32(&mut entry_bytes, count)?;
1054 let mut deque = VecDeque::with_capacity(format::capped_capacity(count));
1055 for _ in 0..count {
1056 let item = format::read_bytes(r)?;
1057 format::write_bytes(&mut entry_bytes, &item)?;
1058 deque.push_back(Bytes::from(item));
1059 }
1060 SnapValue::List(deque)
1061 }
1062 TYPE_SORTED_SET => {
1063 let count = format::read_u32(r)?;
1064 format::validate_collection_count(count, "sorted set")?;
1065 format::write_u32(&mut entry_bytes, count)?;
1066 let mut members = Vec::with_capacity(format::capped_capacity(count));
1067 for _ in 0..count {
1068 let score = format::read_f64(r)?;
1069 format::write_f64(&mut entry_bytes, score)?;
1070 let mb = format::read_bytes(r)?;
1071 format::write_bytes(&mut entry_bytes, &mb)?;
1072 members.push((score, parse_utf8(mb, "member")?));
1073 }
1074 SnapValue::SortedSet(members)
1075 }
1076 TYPE_HASH => {
1077 let count = format::read_u32(r)?;
1078 format::validate_collection_count(count, "hash")?;
1079 format::write_u32(&mut entry_bytes, count)?;
1080 let mut map = HashMap::with_capacity(format::capped_capacity(count));
1081 for _ in 0..count {
1082 let fb = format::read_bytes(r)?;
1083 format::write_bytes(&mut entry_bytes, &fb)?;
1084 let field = parse_utf8(fb, "hash field")?;
1085 let vb = format::read_bytes(r)?;
1086 format::write_bytes(&mut entry_bytes, &vb)?;
1087 map.insert(field, Bytes::from(vb));
1088 }
1089 SnapValue::Hash(map)
1090 }
1091 TYPE_SET => {
1092 let count = format::read_u32(r)?;
1093 format::validate_collection_count(count, "set")?;
1094 format::write_u32(&mut entry_bytes, count)?;
1095 let mut set = HashSet::with_capacity(format::capped_capacity(count));
1096 for _ in 0..count {
1097 let mb = format::read_bytes(r)?;
1098 format::write_bytes(&mut entry_bytes, &mb)?;
1099 set.insert(parse_utf8(mb, "set member")?);
1100 }
1101 SnapValue::Set(set)
1102 }
1103 #[cfg(feature = "vector")]
1104 TYPE_VECTOR => {
1105 let metric = format::read_u8(r)?;
1106 format::write_u8(&mut entry_bytes, metric)?;
1107 let quantization = format::read_u8(r)?;
1108 format::write_u8(&mut entry_bytes, quantization)?;
1109 let connectivity = format::read_u32(r)?;
1110 format::write_u32(&mut entry_bytes, connectivity)?;
1111 let expansion_add = format::read_u32(r)?;
1112 format::write_u32(&mut entry_bytes, expansion_add)?;
1113 let dim = format::read_u32(r)?;
1114 format::write_u32(&mut entry_bytes, dim)?;
1115 let count = format::read_u32(r)?;
1116 format::write_u32(&mut entry_bytes, count)?;
1117 let mut elements = Vec::with_capacity(format::capped_capacity(count));
1118 for _ in 0..count {
1119 let nb = format::read_bytes(r)?;
1120 format::write_bytes(&mut entry_bytes, &nb)?;
1121 let name = parse_utf8(nb, "vector element name")?;
1122 let mut vector = Vec::with_capacity(dim as usize);
1123 for _ in 0..dim {
1124 let v = format::read_f32(r)?;
1125 format::write_f32(&mut entry_bytes, v)?;
1126 vector.push(v);
1127 }
1128 elements.push((name, vector));
1129 }
1130 SnapValue::Vector {
1131 metric,
1132 quantization,
1133 connectivity,
1134 expansion_add,
1135 dim,
1136 elements,
1137 }
1138 }
1139 #[cfg(feature = "protobuf")]
1140 TYPE_PROTO => {
1141 let tn_bytes = format::read_bytes(r)?;
1142 format::write_bytes(&mut entry_bytes, &tn_bytes)?;
1143 let type_name = parse_utf8(tn_bytes, "proto type_name")?;
1144 let data = format::read_bytes(r)?;
1145 format::write_bytes(&mut entry_bytes, &data)?;
1146 SnapValue::Proto {
1147 type_name,
1148 data: Bytes::from(data),
1149 }
1150 }
1151 _ => return Err(FormatError::UnknownTag(type_tag)),
1152 };
1153
1154 let expire_ms = format::read_i64(r)?;
1155 format::write_i64(&mut entry_bytes, expire_ms)?;
1156
1157 Ok((
1158 SnapEntry {
1159 key,
1160 value,
1161 expire_ms,
1162 },
1163 entry_bytes,
1164 ))
1165}
1166
/// Returns the snapshot file location for a shard inside `data_dir`.
///
/// The file name is `shard-<shard_id>.snap`, appended to `data_dir` as a
/// single path component.
pub fn snapshot_path(data_dir: &Path, shard_id: u16) -> PathBuf {
    let file_name = format!("shard-{shard_id}.snap");
    let mut full_path = data_dir.to_path_buf();
    full_path.push(file_name);
    full_path
}
1171
1172#[cfg(test)]
1173mod tests {
1174 use super::*;
1175
/// Return type of the fallible tests in this module: `()` on success, any
/// boxed error on failure, so test bodies can use `?` on every error type.
type Result = std::result::Result<(), Box<dyn std::error::Error>>;
1177
1178 fn temp_dir() -> tempfile::TempDir {
1179 tempfile::tempdir().expect("create temp dir")
1180 }
1181
/// A snapshot finished without any entries must reopen with shard id 0,
/// an entry count of 0, and a footer that verifies.
#[test]
fn empty_snapshot_round_trip() -> Result {
    let scratch = temp_dir();
    let snap_file = scratch.path().join("empty.snap");

    // Create the writer and finish it straight away; the writer is a
    // temporary, so it is gone before the file is reopened below.
    SnapshotWriter::create(&snap_file, 0)?.finish()?;

    let reopened = SnapshotReader::open(&snap_file)?;
    assert_eq!(reopened.shard_id, 0);
    assert_eq!(reopened.entry_count, 0);
    reopened.verify_footer()?;
    Ok(())
}
1198
/// Three string entries written to shard 7 (one of them with an empty value,
/// one with `expire_ms` 5000) must be read back unchanged and in order.
#[test]
fn entries_round_trip() -> Result {
    let scratch = temp_dir();
    let snap_file = scratch.path().join("data.snap");

    let plain = SnapEntry {
        key: "hello".into(),
        value: SnapValue::String(Bytes::from("world")),
        expire_ms: -1,
    };
    let with_ttl = SnapEntry {
        key: "ttl".into(),
        value: SnapValue::String(Bytes::from("expiring")),
        expire_ms: 5000,
    };
    let blank = SnapEntry {
        key: "empty".into(),
        value: SnapValue::String(Bytes::new()),
        expire_ms: -1,
    };
    let entries = vec![plain, with_ttl, blank];

    // Scope the writer so it is dropped before the file is reopened.
    {
        let mut sink = SnapshotWriter::create(&snap_file, 7)?;
        for item in entries.iter() {
            sink.write_entry(item)?;
        }
        sink.finish()?;
    }

    let mut source = SnapshotReader::open(&snap_file)?;
    assert_eq!(source.shard_id, 7);
    assert_eq!(source.entry_count, 3);

    // Drain the reader until it reports no further entry.
    let mut got = Vec::new();
    loop {
        match source.read_entry()? {
            Some(item) => got.push(item),
            None => break,
        }
    }
    assert_eq!(entries, got);
    source.verify_footer()?;
    Ok(())
}
1242
/// Inverting every bit of the file's last byte must make `verify_footer`
/// fail with `FormatError::ChecksumMismatch`.
#[test]
fn corrupt_footer_detected() -> Result {
    let scratch = temp_dir();
    let snap_file = scratch.path().join("corrupt.snap");

    {
        let mut sink = SnapshotWriter::create(&snap_file, 0)?;
        let only_entry = SnapEntry {
            key: "k".into(),
            value: SnapValue::String(Bytes::from("v")),
            expire_ms: -1,
        };
        sink.write_entry(&only_entry)?;
        sink.finish()?;
    }

    // Rewrite the file with its final byte XOR-ed against 0xFF.
    let mut raw = fs::read(&snap_file)?;
    let tail = raw.len() - 1;
    raw[tail] ^= 0xFF;
    fs::write(&snap_file, &raw)?;

    let mut source = SnapshotReader::open(&snap_file)?;
    // Consume the single entry first; only the footer check is expected to fail.
    source.read_entry()?;
    let verdict = source.verify_footer();
    let err = verdict.unwrap_err();
    assert!(matches!(err, FormatError::ChecksumMismatch { .. }));
    Ok(())
}
1272
/// A writer dropped without calling `finish` must leave the previously
/// finished snapshot readable at the same path, and the sibling
/// `atomic.snap.tmp` path must not exist afterwards.
#[test]
fn atomic_rename_prevents_partial_snapshots() -> Result {
    let scratch = temp_dir();
    let snap_file = scratch.path().join("atomic.snap");

    // First pass: a finished snapshot holding the "original" key.
    {
        let mut committed = SnapshotWriter::create(&snap_file, 0)?;
        let original_entry = SnapEntry {
            key: "original".into(),
            value: SnapValue::String(Bytes::from("data")),
            expire_ms: -1,
        };
        committed.write_entry(&original_entry)?;
        committed.finish()?;
    }

    // Second pass: write to the same target but drop the writer unfinished.
    let mut abandoned = SnapshotWriter::create(&snap_file, 0)?;
    let partial_entry = SnapEntry {
        key: "new".into(),
        value: SnapValue::String(Bytes::from("partial")),
        expire_ms: -1,
    };
    abandoned.write_entry(&partial_entry)?;
    drop(abandoned);

    // The file at the target path must still be the first snapshot.
    let mut source = SnapshotReader::open(&snap_file)?;
    let entry = source.read_entry()?.unwrap();
    assert_eq!(entry.key, "original");

    let tmp = snap_file.with_extension("snap.tmp");
    assert!(!tmp.exists(), "drop should clean up incomplete tmp file");
    Ok(())
}
1311
/// An entry written with `expire_ms` 42_000 must come back with the same
/// `expire_ms`, and the footer must verify.
#[test]
fn ttl_entries_preserved() -> Result {
    let scratch = temp_dir();
    let snap_file = scratch.path().join("ttl.snap");

    let expiring = SnapEntry {
        key: "expires".into(),
        value: SnapValue::String(Bytes::from("soon")),
        expire_ms: 42_000,
    };

    {
        let mut sink = SnapshotWriter::create(&snap_file, 0)?;
        sink.write_entry(&expiring)?;
        sink.finish()?;
    }

    let mut source = SnapshotReader::open(&snap_file)?;
    let restored = source.read_entry()?.unwrap();
    assert_eq!(restored.expire_ms, 42_000);
    source.verify_footer()?;
    Ok(())
}
1335
/// A list entry followed by a string entry must be read back unchanged and
/// in the same order.
#[test]
fn list_entries_round_trip() -> Result {
    let scratch = temp_dir();
    let snap_file = scratch.path().join("list.snap");

    let items = VecDeque::from([Bytes::from("a"), Bytes::from("b"), Bytes::from("c")]);

    let list_entry = SnapEntry {
        key: "mylist".into(),
        value: SnapValue::List(items),
        expire_ms: -1,
    };
    let string_entry = SnapEntry {
        key: "mystr".into(),
        value: SnapValue::String(Bytes::from("val")),
        expire_ms: 1000,
    };
    let entries = vec![list_entry, string_entry];

    {
        let mut sink = SnapshotWriter::create(&snap_file, 0)?;
        for item in entries.iter() {
            sink.write_entry(item)?;
        }
        sink.finish()?;
    }

    let mut source = SnapshotReader::open(&snap_file)?;
    let mut got = Vec::new();
    loop {
        match source.read_entry()? {
            Some(item) => got.push(item),
            None => break,
        }
    }
    assert_eq!(entries, got);
    source.verify_footer()?;
    Ok(())
}
1376
/// A sorted-set entry (members supplied in non-ascending score order)
/// followed by a string entry must be read back unchanged.
#[test]
fn sorted_set_entries_round_trip() -> Result {
    let scratch = temp_dir();
    let snap_file = scratch.path().join("zset.snap");

    let scored_members = vec![
        (100.0, "alice".into()),
        (200.0, "bob".into()),
        (150.0, "charlie".into()),
    ];
    let zset_entry = SnapEntry {
        key: "board".into(),
        value: SnapValue::SortedSet(scored_members),
        expire_ms: -1,
    };
    let string_entry = SnapEntry {
        key: "mystr".into(),
        value: SnapValue::String(Bytes::from("val")),
        expire_ms: 1000,
    };
    let entries = vec![zset_entry, string_entry];

    {
        let mut sink = SnapshotWriter::create(&snap_file, 0)?;
        for item in entries.iter() {
            sink.write_entry(item)?;
        }
        sink.finish()?;
    }

    let mut source = SnapshotReader::open(&snap_file)?;
    let mut got = Vec::new();
    loop {
        match source.read_entry()? {
            Some(item) => got.push(item),
            None => break,
        }
    }
    assert_eq!(entries, got);
    source.verify_footer()?;
    Ok(())
}
1416
/// `snapshot_path` must place `shard-<id>.snap` directly under the data dir.
#[test]
fn snapshot_path_format() {
    let data_dir = Path::new("/data");
    let expected = PathBuf::from("/data/shard-5.snap");
    let actual = snapshot_path(data_dir, 5);
    assert_eq!(actual, expected);
}
1422
/// After the last 20 bytes are cut from a two-entry snapshot, the header
/// must still report two entries and the first entry must still be readable,
/// but the second read must fail with `UnexpectedEof` or an `Io` error.
#[test]
fn truncated_snapshot_detected() -> Result {
    let scratch = temp_dir();
    let snap_file = scratch.path().join("truncated.snap");

    {
        let mut sink = SnapshotWriter::create(&snap_file, 0)?;
        let entry_a = SnapEntry {
            key: "a".into(),
            value: SnapValue::String(Bytes::from("1")),
            expire_ms: -1,
        };
        sink.write_entry(&entry_a)?;
        let entry_b = SnapEntry {
            key: "b".into(),
            value: SnapValue::String(Bytes::from("2")),
            expire_ms: 5000,
        };
        sink.write_entry(&entry_b)?;
        sink.finish()?;
    }

    // Write the file back without its final 20 bytes.
    let full = fs::read(&snap_file)?;
    let keep = full.len() - 20;
    fs::write(&snap_file, &full[..keep])?;

    let mut reader = SnapshotReader::open(&snap_file)?;
    assert_eq!(reader.entry_count, 2);

    let first = reader.read_entry()?;
    assert!(first.is_some());

    let err = reader.read_entry().unwrap_err();
    assert!(
        matches!(err, FormatError::UnexpectedEof | FormatError::Io(_)),
        "expected EOF error, got {err:?}"
    );
    Ok(())
}
1464
/// A vector entry with two named 3-dimensional elements must be read back
/// unchanged. Only built with the `vector` feature.
#[cfg(feature = "vector")]
#[test]
fn vector_entries_round_trip() -> Result {
    let scratch = temp_dir();
    let snap_file = scratch.path().join("vec.snap");

    let elements = vec![
        ("doc1".into(), vec![0.1, 0.2, 0.3]),
        ("doc2".into(), vec![0.4, 0.5, 0.6]),
    ];
    let vector_entry = SnapEntry {
        key: "embeddings".into(),
        value: SnapValue::Vector {
            metric: 0,
            quantization: 0,
            connectivity: 16,
            expansion_add: 64,
            dim: 3,
            elements,
        },
        expire_ms: -1,
    };
    let entries = vec![vector_entry];

    {
        let mut sink = SnapshotWriter::create(&snap_file, 0)?;
        for item in entries.iter() {
            sink.write_entry(item)?;
        }
        sink.finish()?;
    }

    let mut source = SnapshotReader::open(&snap_file)?;
    let mut got = Vec::new();
    loop {
        match source.read_entry()? {
            Some(item) => got.push(item),
            None => break,
        }
    }
    assert_eq!(entries, got);
    source.verify_footer()?;
    Ok(())
}
1504
/// A vector entry that declares `dim` 128 but holds no elements must be
/// read back unchanged. Only built with the `vector` feature.
#[cfg(feature = "vector")]
#[test]
fn vector_empty_set_round_trip() -> Result {
    let scratch = temp_dir();
    let snap_file = scratch.path().join("vec_empty.snap");

    let empty_vectors = SnapValue::Vector {
        metric: 2,
        quantization: 2,
        connectivity: 8,
        expansion_add: 32,
        dim: 128,
        elements: vec![],
    };
    let entries = vec![SnapEntry {
        key: "empty_vecs".into(),
        value: empty_vectors,
        expire_ms: 5000,
    }];

    {
        let mut sink = SnapshotWriter::create(&snap_file, 0)?;
        for item in entries.iter() {
            sink.write_entry(item)?;
        }
        sink.finish()?;
    }

    let mut source = SnapshotReader::open(&snap_file)?;
    let restored = source.read_entry()?.unwrap();
    assert_eq!(entries[0], restored);
    source.verify_footer()?;
    Ok(())
}
1538
1539 #[cfg(feature = "encryption")]
1540 mod encrypted {
1541 use super::*;
1542 use crate::encryption::EncryptionKey;
1543
/// Return type of the fallible encryption tests: `()` on success, any boxed
/// error on failure. Declared here explicitly rather than relying on the
/// glob import from the parent module.
type Result = std::result::Result<(), Box<dyn std::error::Error>>;
1545
1546 fn test_key() -> EncryptionKey {
1547 EncryptionKey::from_bytes([0x42; 32])
1548 }
1549
/// Two string entries written with `create_encrypted` to shard 7 must be
/// read back unchanged through `open_encrypted` with the same key.
#[test]
fn encrypted_snapshot_round_trip() -> Result {
    let scratch = temp_dir();
    let snap_file = scratch.path().join("enc.snap");
    let key = test_key();

    let plain = SnapEntry {
        key: "hello".into(),
        value: SnapValue::String(Bytes::from("world")),
        expire_ms: -1,
    };
    let with_ttl = SnapEntry {
        key: "ttl".into(),
        value: SnapValue::String(Bytes::from("expiring")),
        expire_ms: 5000,
    };
    let entries = vec![plain, with_ttl];

    {
        // The writer gets a clone; the original key is handed to the reader.
        let mut sink = SnapshotWriter::create_encrypted(&snap_file, 7, key.clone())?;
        for item in entries.iter() {
            sink.write_entry(item)?;
        }
        sink.finish()?;
    }

    let mut source = SnapshotReader::open_encrypted(&snap_file, key)?;
    assert_eq!(source.shard_id, 7);
    assert_eq!(source.entry_count, 2);

    let mut got = Vec::new();
    loop {
        match source.read_entry()? {
            Some(item) => got.push(item),
            None => break,
        }
    }
    assert_eq!(entries, got);
    source.verify_footer()?;
    Ok(())
}
1589
/// Opening an encrypted snapshot with a different key must succeed, but the
/// first `read_entry` must fail with `FormatError::DecryptionFailed`.
#[test]
fn encrypted_snapshot_wrong_key_fails() -> Result {
    let scratch = temp_dir();
    let snap_file = scratch.path().join("enc_bad.snap");
    let key = test_key();
    let wrong_key = EncryptionKey::from_bytes([0xFF; 32]);

    {
        let mut sink = SnapshotWriter::create_encrypted(&snap_file, 0, key)?;
        let only_entry = SnapEntry {
            key: "k".into(),
            value: SnapValue::String(Bytes::from("v")),
            expire_ms: -1,
        };
        sink.write_entry(&only_entry)?;
        sink.finish()?;
    }

    let mut source = SnapshotReader::open_encrypted(&snap_file, wrong_key)?;
    let attempt = source.read_entry();
    let err = attempt.unwrap_err();
    assert!(matches!(err, FormatError::DecryptionFailed));
    Ok(())
}
1612
/// A snapshot written by the plain `SnapshotWriter::create` must still be
/// readable, footer included, when opened through `open_encrypted`.
#[test]
fn v2_snapshot_readable_with_encryption_key() -> Result {
    let scratch = temp_dir();
    let snap_file = scratch.path().join("v2.snap");
    let key = test_key();

    {
        // Written without any key.
        let mut sink = SnapshotWriter::create(&snap_file, 0)?;
        let only_entry = SnapEntry {
            key: "k".into(),
            value: SnapValue::String(Bytes::from("v")),
            expire_ms: -1,
        };
        sink.write_entry(&only_entry)?;
        sink.finish()?;
    }

    let mut source = SnapshotReader::open_encrypted(&snap_file, key)?;
    let restored = source.read_entry()?.unwrap();
    assert_eq!(restored.key, "k");
    source.verify_footer()?;
    Ok(())
}
1635
/// Opening a snapshot written by `create_encrypted` through the keyless
/// `SnapshotReader::open` must fail with `FormatError::EncryptionRequired`.
#[test]
fn v3_snapshot_without_key_returns_error() -> Result {
    let scratch = temp_dir();
    let snap_file = scratch.path().join("v3_nokey.snap");
    let key = test_key();

    {
        let mut sink = SnapshotWriter::create_encrypted(&snap_file, 0, key)?;
        let only_entry = SnapEntry {
            key: "k".into(),
            value: SnapValue::String(Bytes::from("v")),
            expire_ms: -1,
        };
        sink.write_entry(&only_entry)?;
        sink.finish()?;
    }

    // No `?` here: the error itself is what gets inspected.
    let result = SnapshotReader::open(&snap_file);
    assert!(matches!(result, Err(FormatError::EncryptionRequired)));
    Ok(())
}
1656
/// One entry each of string, list, sorted set, hash and set must be read
/// back unchanged from an encrypted snapshot.
#[test]
fn encrypted_snapshot_with_all_types() -> Result {
    let scratch = temp_dir();
    let snap_file = scratch.path().join("enc_types.snap");
    let key = test_key();

    let list_items = VecDeque::from([Bytes::from("a"), Bytes::from("b")]);

    let mut fields = HashMap::new();
    fields.insert("f1".into(), Bytes::from("v1"));

    let mut members = HashSet::new();
    for member in ["m1", "m2"] {
        members.insert(member.into());
    }

    let string_entry = SnapEntry {
        key: "str".into(),
        value: SnapValue::String(Bytes::from("val")),
        expire_ms: -1,
    };
    let list_entry = SnapEntry {
        key: "list".into(),
        value: SnapValue::List(list_items),
        expire_ms: 1000,
    };
    let zset_entry = SnapEntry {
        key: "zset".into(),
        value: SnapValue::SortedSet(vec![(1.0, "a".into()), (2.0, "b".into())]),
        expire_ms: -1,
    };
    let hash_entry = SnapEntry {
        key: "hash".into(),
        value: SnapValue::Hash(fields),
        expire_ms: -1,
    };
    let set_entry = SnapEntry {
        key: "set".into(),
        value: SnapValue::Set(members),
        expire_ms: -1,
    };
    let entries = vec![string_entry, list_entry, zset_entry, hash_entry, set_entry];

    {
        let mut sink = SnapshotWriter::create_encrypted(&snap_file, 0, key.clone())?;
        for item in entries.iter() {
            sink.write_entry(item)?;
        }
        sink.finish()?;
    }

    let mut source = SnapshotReader::open_encrypted(&snap_file, key)?;
    let mut got = Vec::new();
    loop {
        match source.read_entry()? {
            Some(item) => got.push(item),
            None => break,
        }
    }
    assert_eq!(entries, got);
    source.verify_footer()?;
    Ok(())
}
1719 }
1720
/// `serialize_snap_value` followed by `deserialize_snap_value` must
/// reproduce a string value.
#[test]
fn snap_value_roundtrip_string() {
    let original = SnapValue::String(Bytes::from("hello world"));
    let encoded = serialize_snap_value(&original).unwrap();
    let restored = deserialize_snap_value(&encoded).unwrap();
    assert_eq!(original, restored);
}
1730
/// `serialize_snap_value` followed by `deserialize_snap_value` must
/// reproduce a three-element list value.
#[test]
fn snap_value_roundtrip_list() {
    let items = [Bytes::from("a"), Bytes::from("b"), Bytes::from("c")];
    let original = SnapValue::List(VecDeque::from(items));
    let encoded = serialize_snap_value(&original).unwrap();
    let restored = deserialize_snap_value(&encoded).unwrap();
    assert_eq!(original, restored);
}
1742
/// `serialize_snap_value` followed by `deserialize_snap_value` must
/// reproduce a sorted-set value with fractional scores.
#[test]
fn snap_value_roundtrip_sorted_set() {
    let scored_members = vec![(1.5, "alice".into()), (2.7, "bob".into())];
    let original = SnapValue::SortedSet(scored_members);
    let encoded = serialize_snap_value(&original).unwrap();
    let restored = deserialize_snap_value(&encoded).unwrap();
    assert_eq!(original, restored);
}
1750
/// `serialize_snap_value` followed by `deserialize_snap_value` must
/// reproduce a two-field hash value.
#[test]
fn snap_value_roundtrip_hash() {
    let mut fields = HashMap::new();
    for (name, value) in [("field1", "val1"), ("field2", "val2")] {
        fields.insert(name.into(), Bytes::from(value));
    }
    let original = SnapValue::Hash(fields);
    let encoded = serialize_snap_value(&original).unwrap();
    let restored = deserialize_snap_value(&encoded).unwrap();
    assert_eq!(original, restored);
}
1761
/// `serialize_snap_value` followed by `deserialize_snap_value` must
/// reproduce a three-member set value.
#[test]
fn snap_value_roundtrip_set() {
    let mut members = HashSet::new();
    for member in ["x", "y", "z"] {
        members.insert(member.into());
    }
    let original = SnapValue::Set(members);
    let encoded = serialize_snap_value(&original).unwrap();
    let restored = deserialize_snap_value(&encoded).unwrap();
    assert_eq!(original, restored);
}
1773
/// `serialize_snap_value` followed by `deserialize_snap_value` must
/// reproduce a string value of zero length.
#[test]
fn snap_value_roundtrip_empty_string() {
    let empty_payload = Bytes::new();
    let original = SnapValue::String(empty_payload);
    let encoded = serialize_snap_value(&original).unwrap();
    let restored = deserialize_snap_value(&encoded).unwrap();
    assert_eq!(original, restored);
}
1781
/// `deserialize_snap_value` must return an error for malformed input.
#[test]
fn deserialize_invalid_data() {
    // Empty input.
    assert!(deserialize_snap_value(&[]).is_err());
    // A single byte with value 255 and nothing after it.
    assert!(deserialize_snap_value(&[255]).is_err());
}
1787}