1use std::collections::{HashMap, HashSet, VecDeque};
25use std::fs::{self, File};
26use std::io::{self, BufReader, BufWriter, Write};
27use std::path::{Path, PathBuf};
28
29use bytes::Bytes;
30
31use crate::format::{self, FormatError};
32
/// Type tag byte written before a [`SnapValue::String`] payload.
const TYPE_STRING: u8 = 0;
/// Type tag byte written before a [`SnapValue::List`] payload.
const TYPE_LIST: u8 = 1;
/// Type tag byte written before a [`SnapValue::SortedSet`] payload.
const TYPE_SORTED_SET: u8 = 2;
/// Type tag byte written before a [`SnapValue::Hash`] payload.
const TYPE_HASH: u8 = 3;
/// Type tag byte written before a [`SnapValue::Set`] payload.
const TYPE_SET: u8 = 4;
39
40#[derive(Debug, Clone, PartialEq)]
42pub enum SnapValue {
43 String(Bytes),
45 List(VecDeque<Bytes>),
47 SortedSet(Vec<(f64, String)>),
49 Hash(HashMap<String, Bytes>),
51 Set(HashSet<String>),
53}
54
55#[derive(Debug, Clone, PartialEq)]
57pub struct SnapEntry {
58 pub key: String,
59 pub value: SnapValue,
60 pub expire_ms: i64,
62}
63
64pub struct SnapshotWriter {
70 final_path: PathBuf,
71 tmp_path: PathBuf,
72 writer: BufWriter<File>,
73 hasher: crc32fast::Hasher,
75 count: u32,
76 finished: bool,
79}
80
81impl SnapshotWriter {
82 pub fn create(path: impl Into<PathBuf>, shard_id: u16) -> Result<Self, FormatError> {
85 let final_path = path.into();
86 let tmp_path = final_path.with_extension("snap.tmp");
87
88 let file = File::create(&tmp_path)?;
89 let mut writer = BufWriter::new(file);
90
91 format::write_header(&mut writer, format::SNAP_MAGIC)?;
93 format::write_u16(&mut writer, shard_id)?;
94 format::write_u32(&mut writer, 0)?;
97
98 Ok(Self {
99 final_path,
100 tmp_path,
101 writer,
102 hasher: crc32fast::Hasher::new(),
103 count: 0,
104 finished: false,
105 })
106 }
107
108 pub fn write_entry(&mut self, entry: &SnapEntry) -> Result<(), FormatError> {
110 let mut buf = Vec::new();
111 format::write_bytes(&mut buf, entry.key.as_bytes())?;
112 match &entry.value {
113 SnapValue::String(data) => {
114 format::write_u8(&mut buf, TYPE_STRING)?;
115 format::write_bytes(&mut buf, data)?;
116 }
117 SnapValue::List(deque) => {
118 format::write_u8(&mut buf, TYPE_LIST)?;
119 format::write_u32(&mut buf, deque.len() as u32)?;
120 for item in deque {
121 format::write_bytes(&mut buf, item)?;
122 }
123 }
124 SnapValue::SortedSet(members) => {
125 format::write_u8(&mut buf, TYPE_SORTED_SET)?;
126 format::write_u32(&mut buf, members.len() as u32)?;
127 for (score, member) in members {
128 format::write_f64(&mut buf, *score)?;
129 format::write_bytes(&mut buf, member.as_bytes())?;
130 }
131 }
132 SnapValue::Hash(map) => {
133 format::write_u8(&mut buf, TYPE_HASH)?;
134 format::write_u32(&mut buf, map.len() as u32)?;
135 for (field, value) in map {
136 format::write_bytes(&mut buf, field.as_bytes())?;
137 format::write_bytes(&mut buf, value)?;
138 }
139 }
140 SnapValue::Set(set) => {
141 format::write_u8(&mut buf, TYPE_SET)?;
142 format::write_u32(&mut buf, set.len() as u32)?;
143 for member in set {
144 format::write_bytes(&mut buf, member.as_bytes())?;
145 }
146 }
147 }
148 format::write_i64(&mut buf, entry.expire_ms)?;
149
150 self.hasher.update(&buf);
151 self.writer.write_all(&buf)?;
152 self.count += 1;
153 Ok(())
154 }
155
156 pub fn finish(mut self) -> Result<(), FormatError> {
159 let checksum = self.hasher.clone().finalize();
161 format::write_u32(&mut self.writer, checksum)?;
162 self.writer.flush()?;
163 self.writer.get_ref().sync_all()?;
164
165 {
169 use std::io::{Seek, SeekFrom};
170 let mut file = fs::OpenOptions::new().write(true).open(&self.tmp_path)?;
171 file.seek(SeekFrom::Start(7))?;
173 format::write_u32(&mut file, self.count)?;
174 file.sync_all()?;
175 }
176
177 fs::rename(&self.tmp_path, &self.final_path)?;
179 self.finished = true;
180 Ok(())
181 }
182}
183
184impl Drop for SnapshotWriter {
185 fn drop(&mut self) {
186 if !self.finished {
187 let _ = fs::remove_file(&self.tmp_path);
189 }
190 }
191}
192
193pub struct SnapshotReader {
195 reader: BufReader<File>,
196 pub shard_id: u16,
197 pub entry_count: u32,
198 read_so_far: u32,
199 hasher: crc32fast::Hasher,
200 version: u8,
202}
203
204impl SnapshotReader {
205 pub fn open(path: impl AsRef<Path>) -> Result<Self, FormatError> {
207 let file = File::open(path.as_ref())?;
208 let mut reader = BufReader::new(file);
209
210 let version = format::read_header(&mut reader, format::SNAP_MAGIC)?;
211 let shard_id = format::read_u16(&mut reader)?;
212 let entry_count = format::read_u32(&mut reader)?;
213
214 Ok(Self {
215 reader,
216 shard_id,
217 entry_count,
218 read_so_far: 0,
219 hasher: crc32fast::Hasher::new(),
220 version,
221 })
222 }
223
224 pub fn read_entry(&mut self) -> Result<Option<SnapEntry>, FormatError> {
226 if self.read_so_far >= self.entry_count {
227 return Ok(None);
228 }
229
230 let mut buf = Vec::new();
231
232 let key_bytes = format::read_bytes(&mut self.reader)?;
233 format::write_bytes(&mut buf, &key_bytes).expect("vec write");
234
235 let value = if self.version == 1 {
236 let value_bytes = format::read_bytes(&mut self.reader)?;
238 format::write_bytes(&mut buf, &value_bytes).expect("vec write");
239 SnapValue::String(Bytes::from(value_bytes))
240 } else {
241 let type_tag = format::read_u8(&mut self.reader)?;
243 format::write_u8(&mut buf, type_tag).expect("vec write");
244 match type_tag {
245 TYPE_STRING => {
246 let value_bytes = format::read_bytes(&mut self.reader)?;
247 format::write_bytes(&mut buf, &value_bytes).expect("vec write");
248 SnapValue::String(Bytes::from(value_bytes))
249 }
250 TYPE_LIST => {
251 let count = format::read_u32(&mut self.reader)?;
252 format::write_u32(&mut buf, count).expect("vec write");
253 let mut deque = VecDeque::with_capacity(count as usize);
254 for _ in 0..count {
255 let item = format::read_bytes(&mut self.reader)?;
256 format::write_bytes(&mut buf, &item).expect("vec write");
257 deque.push_back(Bytes::from(item));
258 }
259 SnapValue::List(deque)
260 }
261 TYPE_SORTED_SET => {
262 let count = format::read_u32(&mut self.reader)?;
263 format::write_u32(&mut buf, count).expect("vec write");
264 let mut members = Vec::with_capacity(count as usize);
265 for _ in 0..count {
266 let score = format::read_f64(&mut self.reader)?;
267 format::write_f64(&mut buf, score).expect("vec write");
268 let member_bytes = format::read_bytes(&mut self.reader)?;
269 format::write_bytes(&mut buf, &member_bytes).expect("vec write");
270 let member = String::from_utf8(member_bytes).map_err(|_| {
271 FormatError::Io(io::Error::new(
272 io::ErrorKind::InvalidData,
273 "member is not valid utf-8",
274 ))
275 })?;
276 members.push((score, member));
277 }
278 SnapValue::SortedSet(members)
279 }
280 TYPE_HASH => {
281 let count = format::read_u32(&mut self.reader)?;
282 format::write_u32(&mut buf, count).expect("vec write");
283 let mut map = HashMap::with_capacity(count as usize);
284 for _ in 0..count {
285 let field_bytes = format::read_bytes(&mut self.reader)?;
286 format::write_bytes(&mut buf, &field_bytes).expect("vec write");
287 let field = String::from_utf8(field_bytes).map_err(|_| {
288 FormatError::Io(io::Error::new(
289 io::ErrorKind::InvalidData,
290 "hash field is not valid utf-8",
291 ))
292 })?;
293 let value_bytes = format::read_bytes(&mut self.reader)?;
294 format::write_bytes(&mut buf, &value_bytes).expect("vec write");
295 map.insert(field, Bytes::from(value_bytes));
296 }
297 SnapValue::Hash(map)
298 }
299 TYPE_SET => {
300 let count = format::read_u32(&mut self.reader)?;
301 format::write_u32(&mut buf, count).expect("vec write");
302 let mut set = HashSet::with_capacity(count as usize);
303 for _ in 0..count {
304 let member_bytes = format::read_bytes(&mut self.reader)?;
305 format::write_bytes(&mut buf, &member_bytes).expect("vec write");
306 let member = String::from_utf8(member_bytes).map_err(|_| {
307 FormatError::Io(io::Error::new(
308 io::ErrorKind::InvalidData,
309 "set member is not valid utf-8",
310 ))
311 })?;
312 set.insert(member);
313 }
314 SnapValue::Set(set)
315 }
316 _ => {
317 return Err(FormatError::UnknownTag(type_tag));
318 }
319 }
320 };
321
322 let expire_ms = format::read_i64(&mut self.reader)?;
323 format::write_i64(&mut buf, expire_ms).expect("vec write");
324 self.hasher.update(&buf);
325
326 let key = String::from_utf8(key_bytes).map_err(|_| {
327 FormatError::Io(io::Error::new(
328 io::ErrorKind::InvalidData,
329 "key is not valid utf-8",
330 ))
331 })?;
332
333 self.read_so_far += 1;
334 Ok(Some(SnapEntry {
335 key,
336 value,
337 expire_ms,
338 }))
339 }
340
341 pub fn verify_footer(self) -> Result<(), FormatError> {
344 let expected = self.hasher.finalize();
345 let mut reader = self.reader;
346 let stored = format::read_u32(&mut reader)?;
347 format::verify_crc32_values(expected, stored)
348 }
349}
350
/// Returns the snapshot file location for `shard_id` inside `data_dir`,
/// i.e. `<data_dir>/shard-<shard_id>.snap`.
pub fn snapshot_path(data_dir: &Path, shard_id: u16) -> PathBuf {
    let file_name = format!("shard-{shard_id}.snap");
    let mut full_path = data_dir.to_path_buf();
    full_path.push(file_name);
    full_path
}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a scratch directory that is deleted when the guard is dropped.
    fn temp_dir() -> tempfile::TempDir {
        tempfile::tempdir().expect("create temp dir")
    }

    /// Builds an entry holding a string value.
    fn string_entry(key: &str, value: &'static str, expire_ms: i64) -> SnapEntry {
        SnapEntry {
            key: key.into(),
            value: SnapValue::String(Bytes::from(value)),
            expire_ms,
        }
    }

    /// Writes `entries` to `path` as a complete, finished snapshot.
    fn write_snapshot(path: &Path, shard_id: u16, entries: &[SnapEntry]) {
        let mut writer = SnapshotWriter::create(path, shard_id).unwrap();
        for entry in entries {
            writer.write_entry(entry).unwrap();
        }
        writer.finish().unwrap();
    }

    /// Writes `entries`, reads them back and checks header, payload and footer.
    fn assert_round_trip(file_name: &str, shard_id: u16, entries: &[SnapEntry]) {
        let dir = temp_dir();
        let path = dir.path().join(file_name);
        write_snapshot(&path, shard_id, entries);

        let mut reader = SnapshotReader::open(&path).unwrap();
        assert_eq!(reader.shard_id, shard_id);
        assert_eq!(usize::try_from(reader.entry_count).unwrap(), entries.len());

        let mut got = Vec::new();
        while let Some(entry) = reader.read_entry().unwrap() {
            got.push(entry);
        }
        assert_eq!(entries, got.as_slice());
        reader.verify_footer().unwrap();
    }

    #[test]
    fn empty_snapshot_round_trip() {
        let dir = temp_dir();
        let path = dir.path().join("empty.snap");

        write_snapshot(&path, 0, &[]);

        let reader = SnapshotReader::open(&path).unwrap();
        assert_eq!(reader.shard_id, 0);
        assert_eq!(reader.entry_count, 0);
        reader.verify_footer().unwrap();
    }

    #[test]
    fn entries_round_trip() {
        let entries = vec![
            string_entry("hello", "world", -1),
            string_entry("ttl", "expiring", 5000),
            SnapEntry {
                key: "empty".into(),
                value: SnapValue::String(Bytes::new()),
                expire_ms: -1,
            },
        ];

        assert_round_trip("data.snap", 7, &entries);
    }

    #[test]
    fn corrupt_footer_detected() {
        let dir = temp_dir();
        let path = dir.path().join("corrupt.snap");

        write_snapshot(&path, 0, &[string_entry("k", "v", -1)]);

        // Flip every bit of the final byte, which belongs to the CRC footer.
        let mut data = fs::read(&path).unwrap();
        let last = data.len() - 1;
        data[last] ^= 0xFF;
        fs::write(&path, &data).unwrap();

        let mut reader = SnapshotReader::open(&path).unwrap();
        reader.read_entry().unwrap();
        let err = reader.verify_footer().unwrap_err();
        assert!(matches!(err, FormatError::ChecksumMismatch { .. }));
    }

    #[test]
    fn atomic_rename_prevents_partial_snapshots() {
        let dir = temp_dir();
        let path = dir.path().join("atomic.snap");

        // A complete snapshot is in place first.
        write_snapshot(&path, 0, &[string_entry("original", "data", -1)]);

        // A second writer is abandoned before `finish`.
        {
            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
            writer
                .write_entry(&string_entry("new", "partial", -1))
                .unwrap();
            drop(writer);
        }

        // The original snapshot must be untouched.
        let mut reader = SnapshotReader::open(&path).unwrap();
        let entry = reader.read_entry().unwrap().unwrap();
        assert_eq!(entry.key, "original");

        let tmp = path.with_extension("snap.tmp");
        assert!(!tmp.exists(), "drop should clean up incomplete tmp file");
    }

    #[test]
    fn ttl_entries_preserved() {
        let dir = temp_dir();
        let path = dir.path().join("ttl.snap");

        write_snapshot(&path, 0, &[string_entry("expires", "soon", 42_000)]);

        let mut reader = SnapshotReader::open(&path).unwrap();
        let got = reader.read_entry().unwrap().unwrap();
        assert_eq!(got.expire_ms, 42_000);
        reader.verify_footer().unwrap();
    }

    #[test]
    fn list_entries_round_trip() {
        let mut deque = VecDeque::new();
        deque.push_back(Bytes::from("a"));
        deque.push_back(Bytes::from("b"));
        deque.push_back(Bytes::from("c"));

        let entries = vec![
            SnapEntry {
                key: "mylist".into(),
                value: SnapValue::List(deque),
                expire_ms: -1,
            },
            string_entry("mystr", "val", 1000),
        ];

        assert_round_trip("list.snap", 0, &entries);
    }

    #[test]
    fn sorted_set_entries_round_trip() {
        let entries = vec![
            SnapEntry {
                key: "board".into(),
                value: SnapValue::SortedSet(vec![
                    (100.0, "alice".into()),
                    (200.0, "bob".into()),
                    (150.0, "charlie".into()),
                ]),
                expire_ms: -1,
            },
            string_entry("mystr", "val", 1000),
        ];

        assert_round_trip("zset.snap", 0, &entries);
    }

    #[test]
    fn hash_entries_round_trip() {
        let mut map = HashMap::new();
        map.insert("name".to_string(), Bytes::from("alice"));
        map.insert("blob".to_string(), Bytes::new());

        let entries = vec![
            SnapEntry {
                key: "myhash".into(),
                value: SnapValue::Hash(map),
                expire_ms: -1,
            },
            string_entry("mystr", "val", 1000),
        ];

        assert_round_trip("hash.snap", 0, &entries);
    }

    #[test]
    fn set_entries_round_trip() {
        let mut set = HashSet::new();
        set.insert("a".to_string());
        set.insert("b".to_string());
        set.insert("c".to_string());

        let entries = vec![
            SnapEntry {
                key: "myset".into(),
                value: SnapValue::Set(set),
                expire_ms: 2500,
            },
            SnapEntry {
                key: "emptyset".into(),
                value: SnapValue::Set(HashSet::new()),
                expire_ms: -1,
            },
        ];

        assert_round_trip("set.snap", 0, &entries);
    }

    #[test]
    fn snapshot_path_format() {
        let p = snapshot_path(Path::new("/data"), 5);
        assert_eq!(p, PathBuf::from("/data/shard-5.snap"));
    }
}