1use std::collections::{HashMap, HashSet, VecDeque};
25use std::fs::{self, File, OpenOptions};
26use std::io::{self, BufReader, BufWriter, Write};
27use std::path::{Path, PathBuf};
28
29use bytes::Bytes;
30
31use crate::format::{self, FormatError};
32
// One-byte type tags. `SnapshotWriter::write_entry` emits the matching tag
// right after the key and before the value payload; `SnapshotReader::read_entry`
// dispatches on it (for any header version other than 1).

/// Tag for [`SnapValue::String`].
const TYPE_STRING: u8 = 0;
/// Tag for [`SnapValue::List`].
const TYPE_LIST: u8 = 1;
/// Tag for [`SnapValue::SortedSet`].
const TYPE_SORTED_SET: u8 = 2;
/// Tag for [`SnapValue::Hash`].
const TYPE_HASH: u8 = 3;
/// Tag for [`SnapValue::Set`].
const TYPE_SET: u8 = 4;
39
/// The value half of a snapshot entry, one variant per serialized type tag.
#[derive(Debug, Clone, PartialEq)]
pub enum SnapValue {
    /// A single byte string.
    String(Bytes),
    /// A sequence of byte strings; serialized front-to-back.
    List(VecDeque<Bytes>),
    /// `(score, member)` pairs; serialized in the order they appear in the vector.
    SortedSet(Vec<(f64, String)>),
    /// Field name to byte-string value.
    Hash(HashMap<String, Bytes>),
    /// Distinct string members.
    Set(HashSet<String>),
}
54
/// One key/value record as stored in a snapshot file.
#[derive(Debug, Clone, PartialEq)]
pub struct SnapEntry {
    /// The entry's key; must be valid UTF-8 when read back.
    pub key: String,
    /// The typed value stored under `key`.
    pub value: SnapValue,
    /// Expiry in milliseconds (per the field name), stored verbatim as an `i64`.
    ///
    /// NOTE(review): the tests in this file use `-1` for entries without a TTL —
    /// presumably a "no expiry" sentinel; whether positive values are relative or
    /// absolute is not visible here. Confirm against the caller.
    pub expire_ms: i64,
}
63
/// Streams entries into a temporary file and moves it over the final path
/// only when [`SnapshotWriter::finish`] succeeds.
pub struct SnapshotWriter {
    /// Destination the finished snapshot is renamed to.
    final_path: PathBuf,
    /// Temporary file that receives all writes until `finish`.
    tmp_path: PathBuf,
    /// Buffered handle on `tmp_path`.
    writer: BufWriter<File>,
    /// Running CRC32 over the serialized bytes of every entry written so far.
    hasher: crc32fast::Hasher,
    /// Number of entries written so far; patched into the header by `finish`.
    count: u32,
    /// Set once the rename in `finish` succeeded; tells `Drop` to keep the file.
    finished: bool,
}
80
81impl SnapshotWriter {
82 pub fn create(path: impl Into<PathBuf>, shard_id: u16) -> Result<Self, FormatError> {
85 let final_path = path.into();
86 let tmp_path = final_path.with_extension("snap.tmp");
87
88 let file = {
89 let mut opts = OpenOptions::new();
90 opts.write(true).create(true).truncate(true);
91 #[cfg(unix)]
92 {
93 use std::os::unix::fs::OpenOptionsExt;
94 opts.mode(0o600);
95 }
96 opts.open(&tmp_path)?
97 };
98 let mut writer = BufWriter::new(file);
99
100 format::write_header(&mut writer, format::SNAP_MAGIC)?;
102 format::write_u16(&mut writer, shard_id)?;
103 format::write_u32(&mut writer, 0)?;
106
107 Ok(Self {
108 final_path,
109 tmp_path,
110 writer,
111 hasher: crc32fast::Hasher::new(),
112 count: 0,
113 finished: false,
114 })
115 }
116
117 pub fn write_entry(&mut self, entry: &SnapEntry) -> Result<(), FormatError> {
119 let mut buf = Vec::new();
120 format::write_bytes(&mut buf, entry.key.as_bytes())?;
121 match &entry.value {
122 SnapValue::String(data) => {
123 format::write_u8(&mut buf, TYPE_STRING)?;
124 format::write_bytes(&mut buf, data)?;
125 }
126 SnapValue::List(deque) => {
127 format::write_u8(&mut buf, TYPE_LIST)?;
128 format::write_u32(&mut buf, deque.len() as u32)?;
129 for item in deque {
130 format::write_bytes(&mut buf, item)?;
131 }
132 }
133 SnapValue::SortedSet(members) => {
134 format::write_u8(&mut buf, TYPE_SORTED_SET)?;
135 format::write_u32(&mut buf, members.len() as u32)?;
136 for (score, member) in members {
137 format::write_f64(&mut buf, *score)?;
138 format::write_bytes(&mut buf, member.as_bytes())?;
139 }
140 }
141 SnapValue::Hash(map) => {
142 format::write_u8(&mut buf, TYPE_HASH)?;
143 format::write_u32(&mut buf, map.len() as u32)?;
144 for (field, value) in map {
145 format::write_bytes(&mut buf, field.as_bytes())?;
146 format::write_bytes(&mut buf, value)?;
147 }
148 }
149 SnapValue::Set(set) => {
150 format::write_u8(&mut buf, TYPE_SET)?;
151 format::write_u32(&mut buf, set.len() as u32)?;
152 for member in set {
153 format::write_bytes(&mut buf, member.as_bytes())?;
154 }
155 }
156 }
157 format::write_i64(&mut buf, entry.expire_ms)?;
158
159 self.hasher.update(&buf);
160 self.writer.write_all(&buf)?;
161 self.count += 1;
162 Ok(())
163 }
164
165 pub fn finish(mut self) -> Result<(), FormatError> {
168 let checksum = self.hasher.clone().finalize();
170 format::write_u32(&mut self.writer, checksum)?;
171 self.writer.flush()?;
172 self.writer.get_ref().sync_all()?;
173
174 {
178 use std::io::{Seek, SeekFrom};
179 let mut file = fs::OpenOptions::new().write(true).open(&self.tmp_path)?;
180 file.seek(SeekFrom::Start(7))?;
182 format::write_u32(&mut file, self.count)?;
183 file.sync_all()?;
184 }
185
186 fs::rename(&self.tmp_path, &self.final_path)?;
188 self.finished = true;
189 Ok(())
190 }
191}
192
193impl Drop for SnapshotWriter {
194 fn drop(&mut self) {
195 if !self.finished {
196 let _ = fs::remove_file(&self.tmp_path);
198 }
199 }
200}
201
/// Sequential reader for a snapshot file produced by [`SnapshotWriter`].
pub struct SnapshotReader {
    /// Buffered handle on the snapshot file, positioned after the last item read.
    reader: BufReader<File>,
    /// Shard id read from the file preamble.
    pub shard_id: u16,
    /// Number of entries the file preamble claims to contain.
    pub entry_count: u32,
    /// Number of entries successfully returned by `read_entry` so far.
    read_so_far: u32,
    /// Running CRC32 over the re-serialized bytes of every entry read so far.
    hasher: crc32fast::Hasher,
    /// Format version returned by `format::read_header`; version 1 entries
    /// carry no type tag and are always strings.
    version: u8,
}
212
213impl SnapshotReader {
214 pub fn open(path: impl AsRef<Path>) -> Result<Self, FormatError> {
216 let file = File::open(path.as_ref())?;
217 let mut reader = BufReader::new(file);
218
219 let version = format::read_header(&mut reader, format::SNAP_MAGIC)?;
220 let shard_id = format::read_u16(&mut reader)?;
221 let entry_count = format::read_u32(&mut reader)?;
222
223 Ok(Self {
224 reader,
225 shard_id,
226 entry_count,
227 read_so_far: 0,
228 hasher: crc32fast::Hasher::new(),
229 version,
230 })
231 }
232
233 pub fn read_entry(&mut self) -> Result<Option<SnapEntry>, FormatError> {
235 if self.read_so_far >= self.entry_count {
236 return Ok(None);
237 }
238
239 let mut buf = Vec::new();
240
241 let key_bytes = format::read_bytes(&mut self.reader)?;
242 format::write_bytes(&mut buf, &key_bytes)?;
243
244 let value = if self.version == 1 {
245 let value_bytes = format::read_bytes(&mut self.reader)?;
247 format::write_bytes(&mut buf, &value_bytes)?;
248 SnapValue::String(Bytes::from(value_bytes))
249 } else {
250 let type_tag = format::read_u8(&mut self.reader)?;
252 format::write_u8(&mut buf, type_tag)?;
253 match type_tag {
254 TYPE_STRING => {
255 let value_bytes = format::read_bytes(&mut self.reader)?;
256 format::write_bytes(&mut buf, &value_bytes)?;
257 SnapValue::String(Bytes::from(value_bytes))
258 }
259 TYPE_LIST => {
260 let count = format::read_u32(&mut self.reader)?;
261 format::write_u32(&mut buf, count)?;
262 let mut deque = VecDeque::with_capacity(count as usize);
263 for _ in 0..count {
264 let item = format::read_bytes(&mut self.reader)?;
265 format::write_bytes(&mut buf, &item)?;
266 deque.push_back(Bytes::from(item));
267 }
268 SnapValue::List(deque)
269 }
270 TYPE_SORTED_SET => {
271 let count = format::read_u32(&mut self.reader)?;
272 format::write_u32(&mut buf, count)?;
273 let mut members = Vec::with_capacity(count as usize);
274 for _ in 0..count {
275 let score = format::read_f64(&mut self.reader)?;
276 format::write_f64(&mut buf, score)?;
277 let member_bytes = format::read_bytes(&mut self.reader)?;
278 format::write_bytes(&mut buf, &member_bytes)?;
279 let member = String::from_utf8(member_bytes).map_err(|_| {
280 FormatError::Io(io::Error::new(
281 io::ErrorKind::InvalidData,
282 "member is not valid utf-8",
283 ))
284 })?;
285 members.push((score, member));
286 }
287 SnapValue::SortedSet(members)
288 }
289 TYPE_HASH => {
290 let count = format::read_u32(&mut self.reader)?;
291 format::write_u32(&mut buf, count)?;
292 let mut map = HashMap::with_capacity(count as usize);
293 for _ in 0..count {
294 let field_bytes = format::read_bytes(&mut self.reader)?;
295 format::write_bytes(&mut buf, &field_bytes)?;
296 let field = String::from_utf8(field_bytes).map_err(|_| {
297 FormatError::Io(io::Error::new(
298 io::ErrorKind::InvalidData,
299 "hash field is not valid utf-8",
300 ))
301 })?;
302 let value_bytes = format::read_bytes(&mut self.reader)?;
303 format::write_bytes(&mut buf, &value_bytes)?;
304 map.insert(field, Bytes::from(value_bytes));
305 }
306 SnapValue::Hash(map)
307 }
308 TYPE_SET => {
309 let count = format::read_u32(&mut self.reader)?;
310 format::write_u32(&mut buf, count)?;
311 let mut set = HashSet::with_capacity(count as usize);
312 for _ in 0..count {
313 let member_bytes = format::read_bytes(&mut self.reader)?;
314 format::write_bytes(&mut buf, &member_bytes)?;
315 let member = String::from_utf8(member_bytes).map_err(|_| {
316 FormatError::Io(io::Error::new(
317 io::ErrorKind::InvalidData,
318 "set member is not valid utf-8",
319 ))
320 })?;
321 set.insert(member);
322 }
323 SnapValue::Set(set)
324 }
325 _ => {
326 return Err(FormatError::UnknownTag(type_tag));
327 }
328 }
329 };
330
331 let expire_ms = format::read_i64(&mut self.reader)?;
332 format::write_i64(&mut buf, expire_ms)?;
333 self.hasher.update(&buf);
334
335 let key = String::from_utf8(key_bytes).map_err(|_| {
336 FormatError::Io(io::Error::new(
337 io::ErrorKind::InvalidData,
338 "key is not valid utf-8",
339 ))
340 })?;
341
342 self.read_so_far += 1;
343 Ok(Some(SnapEntry {
344 key,
345 value,
346 expire_ms,
347 }))
348 }
349
350 pub fn verify_footer(self) -> Result<(), FormatError> {
353 let expected = self.hasher.finalize();
354 let mut reader = self.reader;
355 let stored = format::read_u32(&mut reader)?;
356 format::verify_crc32_values(expected, stored)
357 }
358}
359
/// Builds the snapshot file path for a shard: `<data_dir>/shard-<shard_id>.snap`.
pub fn snapshot_path(data_dir: &Path, shard_id: u16) -> PathBuf {
    let mut location = data_dir.to_path_buf();
    location.push(format!("shard-{shard_id}.snap"));
    location
}
364
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a scratch directory that is removed when the guard is dropped.
    fn temp_dir() -> tempfile::TempDir {
        tempfile::tempdir().expect("create temp dir")
    }

    /// Shorthand for a string-valued entry.
    fn string_entry(key: &str, value: &'static str, expire_ms: i64) -> SnapEntry {
        SnapEntry {
            key: key.into(),
            value: SnapValue::String(Bytes::from(value)),
            expire_ms,
        }
    }

    /// Writes `entries` to `path` as a complete, finished snapshot.
    fn write_snapshot(path: &Path, shard_id: u16, entries: &[SnapEntry]) {
        let mut writer = SnapshotWriter::create(path, shard_id).unwrap();
        for entry in entries {
            writer.write_entry(entry).unwrap();
        }
        writer.finish().unwrap();
    }

    /// Reads every entry from `path` and verifies the checksum footer.
    fn read_all(path: &Path) -> Vec<SnapEntry> {
        let mut reader = SnapshotReader::open(path).unwrap();
        let mut got = Vec::new();
        while let Some(entry) = reader.read_entry().unwrap() {
            got.push(entry);
        }
        reader.verify_footer().unwrap();
        got
    }

    /// Writes `entries` to a fresh file and asserts they read back unchanged.
    fn assert_round_trip(file_name: &str, entries: Vec<SnapEntry>) {
        let dir = temp_dir();
        let path = dir.path().join(file_name);
        write_snapshot(&path, 0, &entries);
        assert_eq!(entries, read_all(&path));
    }

    #[test]
    fn empty_snapshot_round_trip() {
        let dir = temp_dir();
        let path = dir.path().join("empty.snap");

        write_snapshot(&path, 0, &[]);

        let reader = SnapshotReader::open(&path).unwrap();
        assert_eq!(reader.shard_id, 0);
        assert_eq!(reader.entry_count, 0);
        reader.verify_footer().unwrap();
    }

    #[test]
    fn entries_round_trip() {
        let dir = temp_dir();
        let path = dir.path().join("data.snap");

        let entries = vec![
            string_entry("hello", "world", -1),
            string_entry("ttl", "expiring", 5000),
            SnapEntry {
                key: "empty".into(),
                value: SnapValue::String(Bytes::new()),
                expire_ms: -1,
            },
        ];

        write_snapshot(&path, 7, &entries);

        let mut reader = SnapshotReader::open(&path).unwrap();
        assert_eq!(reader.shard_id, 7);
        assert_eq!(reader.entry_count, 3);

        let mut got = Vec::new();
        while let Some(entry) = reader.read_entry().unwrap() {
            got.push(entry);
        }
        assert_eq!(entries, got);
        reader.verify_footer().unwrap();
    }

    #[test]
    fn corrupt_footer_detected() {
        let dir = temp_dir();
        let path = dir.path().join("corrupt.snap");

        write_snapshot(&path, 0, &[string_entry("k", "v", -1)]);

        // Flip every bit of the final byte, which belongs to the stored checksum.
        let mut data = fs::read(&path).unwrap();
        let last = data.len() - 1;
        data[last] ^= 0xFF;
        fs::write(&path, &data).unwrap();

        let mut reader = SnapshotReader::open(&path).unwrap();
        reader.read_entry().unwrap();
        let err = reader.verify_footer().unwrap_err();
        assert!(matches!(err, FormatError::ChecksumMismatch { .. }));
    }

    #[test]
    fn atomic_rename_prevents_partial_snapshots() {
        let dir = temp_dir();
        let path = dir.path().join("atomic.snap");

        // A complete snapshot is in place first.
        write_snapshot(&path, 0, &[string_entry("original", "data", -1)]);

        // A second writer is abandoned before `finish`.
        {
            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
            writer
                .write_entry(&string_entry("new", "partial", -1))
                .unwrap();
            drop(writer);
        }

        // The final path must still hold the first snapshot.
        let mut reader = SnapshotReader::open(&path).unwrap();
        let entry = reader.read_entry().unwrap().unwrap();
        assert_eq!(entry.key, "original");

        let tmp = path.with_extension("snap.tmp");
        assert!(!tmp.exists(), "drop should clean up incomplete tmp file");
    }

    #[test]
    fn ttl_entries_preserved() {
        let dir = temp_dir();
        let path = dir.path().join("ttl.snap");

        write_snapshot(&path, 0, &[string_entry("expires", "soon", 42_000)]);

        let got = read_all(&path);
        assert_eq!(got.len(), 1);
        assert_eq!(got[0].expire_ms, 42_000);
    }

    #[test]
    fn list_entries_round_trip() {
        let mut deque = VecDeque::new();
        deque.push_back(Bytes::from("a"));
        deque.push_back(Bytes::from("b"));
        deque.push_back(Bytes::from("c"));

        assert_round_trip(
            "list.snap",
            vec![
                SnapEntry {
                    key: "mylist".into(),
                    value: SnapValue::List(deque),
                    expire_ms: -1,
                },
                string_entry("mystr", "val", 1000),
            ],
        );
    }

    #[test]
    fn sorted_set_entries_round_trip() {
        assert_round_trip(
            "zset.snap",
            vec![
                SnapEntry {
                    key: "board".into(),
                    value: SnapValue::SortedSet(vec![
                        (100.0, "alice".into()),
                        (200.0, "bob".into()),
                        (150.0, "charlie".into()),
                    ]),
                    expire_ms: -1,
                },
                string_entry("mystr", "val", 1000),
            ],
        );
    }

    #[test]
    fn hash_entries_round_trip() {
        let mut map = HashMap::new();
        map.insert("name".to_string(), Bytes::from("ember"));
        map.insert("kind".to_string(), Bytes::from("cache"));
        map.insert("blank".to_string(), Bytes::new());

        assert_round_trip(
            "hash.snap",
            vec![
                SnapEntry {
                    key: "myhash".into(),
                    value: SnapValue::Hash(map),
                    expire_ms: -1,
                },
                string_entry("mystr", "val", 1000),
            ],
        );
    }

    #[test]
    fn set_entries_round_trip() {
        let set: HashSet<String> = ["red", "green", "blue"]
            .iter()
            .map(|member| member.to_string())
            .collect();

        assert_round_trip(
            "set.snap",
            vec![
                SnapEntry {
                    key: "myset".into(),
                    value: SnapValue::Set(set),
                    expire_ms: 2500,
                },
                string_entry("mystr", "val", 1000),
            ],
        );
    }

    #[test]
    fn snapshot_path_format() {
        let p = snapshot_path(Path::new("/data"), 5);
        assert_eq!(p, PathBuf::from("/data/shard-5.snap"));
    }
}