1use std::collections::VecDeque;
25use std::fs::{self, File};
26use std::io::{self, BufReader, BufWriter, Write};
27use std::path::{Path, PathBuf};
28
29use bytes::Bytes;
30
31use crate::format::{self, FormatError};
32
/// Type tag byte written before a [`SnapValue::String`] payload.
const TYPE_STRING: u8 = 0;
/// Type tag byte written before a [`SnapValue::List`] payload.
const TYPE_LIST: u8 = 1;
/// Type tag byte written before a [`SnapValue::SortedSet`] payload.
const TYPE_SORTED_SET: u8 = 2;
37
/// Value payload carried by a [`SnapEntry`].
///
/// Each variant corresponds to one of the `TYPE_*` tag bytes used when the
/// entry is serialized.
#[derive(Debug, Clone, PartialEq)]
pub enum SnapValue {
    /// A single byte string.
    String(Bytes),
    /// A sequence of byte strings, serialized in iteration (front-to-back) order.
    List(VecDeque<Bytes>),
    /// `(score, member)` pairs, serialized in vector order.
    SortedSet(Vec<(f64, String)>),
}
48
/// One key/value record stored in a snapshot file.
#[derive(Debug, Clone, PartialEq)]
pub struct SnapEntry {
    /// Entry key; must be valid UTF-8 to be read back successfully.
    pub key: String,
    /// The stored value.
    pub value: SnapValue,
    /// Expiry value, serialized verbatim as an `i64` after the value.
    // NOTE(review): the tests use -1 for entries without a TTL — presumably
    // a "never expires" sentinel; confirm the units/sentinel with callers.
    pub expire_ms: i64,
}
57
/// Streams entries into a temporary file and renames it over the final
/// snapshot path once [`SnapshotWriter::finish`] succeeds.
pub struct SnapshotWriter {
    /// Destination the temp file is renamed to by `finish`.
    final_path: PathBuf,
    /// Temporary file that receives all writes until `finish`.
    tmp_path: PathBuf,
    /// Buffered handle to the temp file.
    writer: BufWriter<File>,
    /// Running CRC32 over every serialized entry (the header is not hashed).
    hasher: crc32fast::Hasher,
    /// Number of entries written so far; patched into the header by `finish`.
    count: u32,
    /// Set once the rename has succeeded, so `Drop` keeps the file.
    finished: bool,
}
74
75impl SnapshotWriter {
76 pub fn create(path: impl Into<PathBuf>, shard_id: u16) -> Result<Self, FormatError> {
79 let final_path = path.into();
80 let tmp_path = final_path.with_extension("snap.tmp");
81
82 let file = File::create(&tmp_path)?;
83 let mut writer = BufWriter::new(file);
84
85 format::write_header(&mut writer, format::SNAP_MAGIC)?;
87 format::write_u16(&mut writer, shard_id)?;
88 format::write_u32(&mut writer, 0)?;
91
92 Ok(Self {
93 final_path,
94 tmp_path,
95 writer,
96 hasher: crc32fast::Hasher::new(),
97 count: 0,
98 finished: false,
99 })
100 }
101
102 pub fn write_entry(&mut self, entry: &SnapEntry) -> Result<(), FormatError> {
104 let mut buf = Vec::new();
105 format::write_bytes(&mut buf, entry.key.as_bytes())?;
106 match &entry.value {
107 SnapValue::String(data) => {
108 format::write_u8(&mut buf, TYPE_STRING)?;
109 format::write_bytes(&mut buf, data)?;
110 }
111 SnapValue::List(deque) => {
112 format::write_u8(&mut buf, TYPE_LIST)?;
113 format::write_u32(&mut buf, deque.len() as u32)?;
114 for item in deque {
115 format::write_bytes(&mut buf, item)?;
116 }
117 }
118 SnapValue::SortedSet(members) => {
119 format::write_u8(&mut buf, TYPE_SORTED_SET)?;
120 format::write_u32(&mut buf, members.len() as u32)?;
121 for (score, member) in members {
122 format::write_f64(&mut buf, *score)?;
123 format::write_bytes(&mut buf, member.as_bytes())?;
124 }
125 }
126 }
127 format::write_i64(&mut buf, entry.expire_ms)?;
128
129 self.hasher.update(&buf);
130 self.writer.write_all(&buf)?;
131 self.count += 1;
132 Ok(())
133 }
134
135 pub fn finish(mut self) -> Result<(), FormatError> {
138 let checksum = self.hasher.clone().finalize();
140 format::write_u32(&mut self.writer, checksum)?;
141 self.writer.flush()?;
142 self.writer.get_ref().sync_all()?;
143
144 {
148 use std::io::{Seek, SeekFrom};
149 let mut file = fs::OpenOptions::new().write(true).open(&self.tmp_path)?;
150 file.seek(SeekFrom::Start(7))?;
152 format::write_u32(&mut file, self.count)?;
153 file.sync_all()?;
154 }
155
156 fs::rename(&self.tmp_path, &self.final_path)?;
158 self.finished = true;
159 Ok(())
160 }
161}
162
163impl Drop for SnapshotWriter {
164 fn drop(&mut self) {
165 if !self.finished {
166 let _ = fs::remove_file(&self.tmp_path);
168 }
169 }
170}
171
/// Sequential reader for a snapshot file produced by [`SnapshotWriter`].
pub struct SnapshotReader {
    /// Buffered handle positioned just past the last item read.
    reader: BufReader<File>,
    /// Shard id read from the file header.
    pub shard_id: u16,
    /// Number of entries the header claims the file contains.
    pub entry_count: u32,
    /// Number of entries returned by `read_entry` so far.
    read_so_far: u32,
    /// Running CRC32 over the entries read so far, compared by `verify_footer`.
    hasher: crc32fast::Hasher,
    /// Format version returned by `format::read_header`; version 1 entries
    /// carry no type tag.
    version: u8,
}
182
183impl SnapshotReader {
184 pub fn open(path: impl AsRef<Path>) -> Result<Self, FormatError> {
186 let file = File::open(path.as_ref())?;
187 let mut reader = BufReader::new(file);
188
189 let version = format::read_header(&mut reader, format::SNAP_MAGIC)?;
190 let shard_id = format::read_u16(&mut reader)?;
191 let entry_count = format::read_u32(&mut reader)?;
192
193 Ok(Self {
194 reader,
195 shard_id,
196 entry_count,
197 read_so_far: 0,
198 hasher: crc32fast::Hasher::new(),
199 version,
200 })
201 }
202
203 pub fn read_entry(&mut self) -> Result<Option<SnapEntry>, FormatError> {
205 if self.read_so_far >= self.entry_count {
206 return Ok(None);
207 }
208
209 let mut buf = Vec::new();
210
211 let key_bytes = format::read_bytes(&mut self.reader)?;
212 format::write_bytes(&mut buf, &key_bytes).expect("vec write");
213
214 let value = if self.version == 1 {
215 let value_bytes = format::read_bytes(&mut self.reader)?;
217 format::write_bytes(&mut buf, &value_bytes).expect("vec write");
218 SnapValue::String(Bytes::from(value_bytes))
219 } else {
220 let type_tag = format::read_u8(&mut self.reader)?;
222 format::write_u8(&mut buf, type_tag).expect("vec write");
223 match type_tag {
224 TYPE_STRING => {
225 let value_bytes = format::read_bytes(&mut self.reader)?;
226 format::write_bytes(&mut buf, &value_bytes).expect("vec write");
227 SnapValue::String(Bytes::from(value_bytes))
228 }
229 TYPE_LIST => {
230 let count = format::read_u32(&mut self.reader)?;
231 format::write_u32(&mut buf, count).expect("vec write");
232 let mut deque = VecDeque::with_capacity(count as usize);
233 for _ in 0..count {
234 let item = format::read_bytes(&mut self.reader)?;
235 format::write_bytes(&mut buf, &item).expect("vec write");
236 deque.push_back(Bytes::from(item));
237 }
238 SnapValue::List(deque)
239 }
240 TYPE_SORTED_SET => {
241 let count = format::read_u32(&mut self.reader)?;
242 format::write_u32(&mut buf, count).expect("vec write");
243 let mut members = Vec::with_capacity(count as usize);
244 for _ in 0..count {
245 let score = format::read_f64(&mut self.reader)?;
246 format::write_f64(&mut buf, score).expect("vec write");
247 let member_bytes = format::read_bytes(&mut self.reader)?;
248 format::write_bytes(&mut buf, &member_bytes).expect("vec write");
249 let member = String::from_utf8(member_bytes).map_err(|_| {
250 FormatError::Io(io::Error::new(
251 io::ErrorKind::InvalidData,
252 "member is not valid utf-8",
253 ))
254 })?;
255 members.push((score, member));
256 }
257 SnapValue::SortedSet(members)
258 }
259 _ => {
260 return Err(FormatError::UnknownTag(type_tag));
261 }
262 }
263 };
264
265 let expire_ms = format::read_i64(&mut self.reader)?;
266 format::write_i64(&mut buf, expire_ms).expect("vec write");
267 self.hasher.update(&buf);
268
269 let key = String::from_utf8(key_bytes).map_err(|_| {
270 FormatError::Io(io::Error::new(
271 io::ErrorKind::InvalidData,
272 "key is not valid utf-8",
273 ))
274 })?;
275
276 self.read_so_far += 1;
277 Ok(Some(SnapEntry {
278 key,
279 value,
280 expire_ms,
281 }))
282 }
283
284 pub fn verify_footer(self) -> Result<(), FormatError> {
287 let expected = self.hasher.finalize();
288 let mut reader = self.reader;
289 let stored = format::read_u32(&mut reader)?;
290 format::verify_crc32_values(expected, stored)
291 }
292}
293
/// Returns the snapshot file location for `shard_id` inside `data_dir`,
/// i.e. `<data_dir>/shard-<shard_id>.snap`.
pub fn snapshot_path(data_dir: &Path, shard_id: u16) -> PathBuf {
    let file_name = format!("shard-{shard_id}.snap");
    let mut full_path = data_dir.to_path_buf();
    full_path.push(file_name);
    full_path
}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a scratch directory for one test.
    fn temp_dir() -> tempfile::TempDir {
        tempfile::tempdir().expect("create temp dir")
    }

    /// Shorthand for an entry holding a string value.
    fn string_entry(key: &str, value: &'static str, expire_ms: i64) -> SnapEntry {
        SnapEntry {
            key: key.into(),
            value: SnapValue::String(Bytes::from(value)),
            expire_ms,
        }
    }

    /// Writes `entries` into a finished snapshot at `path` for `shard_id`.
    fn write_snapshot(path: &Path, shard_id: u16, entries: &[SnapEntry]) {
        let mut writer = SnapshotWriter::create(path, shard_id).unwrap();
        for entry in entries {
            writer.write_entry(entry).unwrap();
        }
        writer.finish().unwrap();
    }

    /// Reads entries from `reader` until it reports the end of the snapshot.
    fn read_all(reader: &mut SnapshotReader) -> Vec<SnapEntry> {
        let mut collected = Vec::new();
        while let Some(entry) = reader.read_entry().unwrap() {
            collected.push(entry);
        }
        collected
    }

    /// A snapshot with no entries still carries a valid header and footer.
    #[test]
    fn empty_snapshot_round_trip() {
        let dir = temp_dir();
        let path = dir.path().join("empty.snap");

        write_snapshot(&path, 0, &[]);

        let reader = SnapshotReader::open(&path).unwrap();
        assert_eq!(reader.shard_id, 0);
        assert_eq!(reader.entry_count, 0);
        reader.verify_footer().unwrap();
    }

    /// String entries (with and without TTL, including an empty value)
    /// survive a write/read cycle.
    #[test]
    fn entries_round_trip() {
        let dir = temp_dir();
        let path = dir.path().join("data.snap");

        let entries = vec![
            string_entry("hello", "world", -1),
            string_entry("ttl", "expiring", 5000),
            SnapEntry {
                key: "empty".into(),
                value: SnapValue::String(Bytes::new()),
                expire_ms: -1,
            },
        ];

        write_snapshot(&path, 7, &entries);

        let mut reader = SnapshotReader::open(&path).unwrap();
        assert_eq!(reader.shard_id, 7);
        assert_eq!(reader.entry_count, 3);

        let got = read_all(&mut reader);
        assert_eq!(entries, got);
        reader.verify_footer().unwrap();
    }

    /// Flipping the last byte of the file (part of the CRC footer) must be
    /// reported as a checksum mismatch.
    #[test]
    fn corrupt_footer_detected() {
        let dir = temp_dir();
        let path = dir.path().join("corrupt.snap");

        write_snapshot(&path, 0, &[string_entry("k", "v", -1)]);

        let mut data = fs::read(&path).unwrap();
        *data.last_mut().expect("snapshot file is not empty") ^= 0xFF;
        fs::write(&path, &data).unwrap();

        let mut reader = SnapshotReader::open(&path).unwrap();
        reader.read_entry().unwrap();
        let err = reader.verify_footer().unwrap_err();
        assert!(matches!(err, FormatError::ChecksumMismatch { .. }));
    }

    /// Dropping an unfinished writer leaves the previous snapshot intact and
    /// removes the temp file.
    #[test]
    fn atomic_rename_prevents_partial_snapshots() {
        let dir = temp_dir();
        let path = dir.path().join("atomic.snap");

        // First snapshot completes normally.
        write_snapshot(&path, 0, &[string_entry("original", "data", -1)]);

        // Second snapshot is abandoned before `finish`.
        {
            let mut writer = SnapshotWriter::create(&path, 0).unwrap();
            writer
                .write_entry(&string_entry("new", "partial", -1))
                .unwrap();
            drop(writer);
        }

        let mut reader = SnapshotReader::open(&path).unwrap();
        let entry = reader.read_entry().unwrap().unwrap();
        assert_eq!(entry.key, "original");

        let tmp = path.with_extension("snap.tmp");
        assert!(!tmp.exists(), "drop should clean up incomplete tmp file");
    }

    /// The expiry value is stored and read back unchanged.
    #[test]
    fn ttl_entries_preserved() {
        let dir = temp_dir();
        let path = dir.path().join("ttl.snap");

        write_snapshot(&path, 0, &[string_entry("expires", "soon", 42_000)]);

        let mut reader = SnapshotReader::open(&path).unwrap();
        let got = reader.read_entry().unwrap().unwrap();
        assert_eq!(got.expire_ms, 42_000);
        reader.verify_footer().unwrap();
    }

    /// List values round-trip alongside string values.
    #[test]
    fn list_entries_round_trip() {
        let dir = temp_dir();
        let path = dir.path().join("list.snap");

        let deque = VecDeque::from(vec![
            Bytes::from("a"),
            Bytes::from("b"),
            Bytes::from("c"),
        ]);

        let entries = vec![
            SnapEntry {
                key: "mylist".into(),
                value: SnapValue::List(deque),
                expire_ms: -1,
            },
            string_entry("mystr", "val", 1000),
        ];

        write_snapshot(&path, 0, &entries);

        let mut reader = SnapshotReader::open(&path).unwrap();
        let got = read_all(&mut reader);
        assert_eq!(entries, got);
        reader.verify_footer().unwrap();
    }

    /// Sorted-set values round-trip alongside string values, preserving
    /// member order.
    #[test]
    fn sorted_set_entries_round_trip() {
        let dir = temp_dir();
        let path = dir.path().join("zset.snap");

        let entries = vec![
            SnapEntry {
                key: "board".into(),
                value: SnapValue::SortedSet(vec![
                    (100.0, "alice".into()),
                    (200.0, "bob".into()),
                    (150.0, "charlie".into()),
                ]),
                expire_ms: -1,
            },
            string_entry("mystr", "val", 1000),
        ];

        write_snapshot(&path, 0, &entries);

        let mut reader = SnapshotReader::open(&path).unwrap();
        let got = read_all(&mut reader);
        assert_eq!(entries, got);
        reader.verify_footer().unwrap();
    }

    /// The snapshot file name embeds the shard id.
    #[test]
    fn snapshot_path_format() {
        let expected = PathBuf::from("/data/shard-5.snap");
        assert_eq!(snapshot_path(Path::new("/data"), 5), expected);
    }
}