1use std::collections::{HashMap, HashSet, VecDeque};
11use std::path::Path;
12use std::time::Duration;
13
14use bytes::Bytes;
15use tracing::warn;
16
17use crate::aof::{self, AofReader, AofRecord};
18use crate::format::FormatError;
19use crate::snapshot::{self, SnapValue, SnapshotReader};
20
21#[derive(Debug, Clone)]
23pub enum RecoveredValue {
24 String(Bytes),
25 List(VecDeque<Bytes>),
26 SortedSet(Vec<(f64, String)>),
28 Hash(HashMap<String, Bytes>),
30 Set(HashSet<String>),
32}
33
34impl From<SnapValue> for RecoveredValue {
35 fn from(sv: SnapValue) -> Self {
36 match sv {
37 SnapValue::String(data) => RecoveredValue::String(data),
38 SnapValue::List(deque) => RecoveredValue::List(deque),
39 SnapValue::SortedSet(members) => RecoveredValue::SortedSet(members),
40 SnapValue::Hash(map) => RecoveredValue::Hash(map),
41 SnapValue::Set(set) => RecoveredValue::Set(set),
42 }
43 }
44}
45
/// One key recovered for a shard, as returned in [`RecoveryResult::entries`].
#[derive(Debug, Clone)]
pub struct RecoveredEntry {
    /// The key name.
    pub key: String,
    /// The recovered value for `key`.
    pub value: RecoveredValue,
    /// `None` when the key has no expiry (a negative stored expiry value);
    /// otherwise the stored millisecond expiry value as a `Duration`.
    // NOTE(review): presumably the time remaining until expiry rather than an
    // absolute deadline — confirm against the snapshot/AOF writers.
    pub ttl: Option<Duration>,
}
54
/// Outcome of [`recover_shard`]: the surviving entries plus which persistence
/// sources contributed to them.
#[derive(Debug)]
pub struct RecoveryResult {
    /// Every recovered key whose stored expiry is not `0`, in no particular
    /// order (collected from a `HashMap`).
    pub entries: Vec<RecoveredEntry>,
    /// `true` if a snapshot file was found and read in full, including
    /// footer verification.
    pub loaded_snapshot: bool,
    /// `true` if AOF replay ran to completion without error and read at least
    /// one record.
    pub replayed_aof: bool,
}
65
66pub fn recover_shard(data_dir: &Path, shard_id: u16) -> RecoveryResult {
71 let mut map: HashMap<String, (RecoveredValue, i64)> = HashMap::new();
73 let mut loaded_snapshot = false;
74 let mut replayed_aof = false;
75
76 let snap_path = snapshot::snapshot_path(data_dir, shard_id);
78 if snap_path.exists() {
79 match load_snapshot(&snap_path) {
80 Ok(entries) => {
81 for (key, value, ttl_ms) in entries {
82 map.insert(key, (RecoveredValue::from(value), ttl_ms));
83 }
84 loaded_snapshot = true;
85 }
86 Err(e) => {
87 warn!(shard_id, "failed to load snapshot, starting empty: {e}");
88 }
89 }
90 }
91
92 let aof_path = aof::aof_path(data_dir, shard_id);
94 if aof_path.exists() {
95 match replay_aof(&aof_path, &mut map) {
96 Ok(count) => {
97 if count > 0 {
98 replayed_aof = true;
99 }
100 }
101 Err(e) => {
102 warn!(
103 shard_id,
104 "failed to replay aof, using snapshot state only: {e}"
105 );
106 }
107 }
108 }
109
110 let entries = map
112 .into_iter()
113 .filter(|(_, (_, ttl_ms))| *ttl_ms != 0) .map(|(key, (value, ttl_ms))| RecoveredEntry {
115 key,
116 value,
117 ttl: if ttl_ms < 0 {
118 None
119 } else {
120 Some(Duration::from_millis(ttl_ms as u64))
121 },
122 })
123 .collect();
124
125 RecoveryResult {
126 entries,
127 loaded_snapshot,
128 replayed_aof,
129 }
130}
131
132fn load_snapshot(path: &Path) -> Result<Vec<(String, SnapValue, i64)>, FormatError> {
135 let mut reader = SnapshotReader::open(path)?;
136 let mut entries = Vec::new();
137
138 while let Some(entry) = reader.read_entry()? {
139 entries.push((entry.key, entry.value, entry.expire_ms));
141 }
142
143 reader.verify_footer()?;
144 Ok(entries)
145}
146
147fn apply_incr(map: &mut HashMap<String, (RecoveredValue, i64)>, key: String, delta: i64) {
150 let entry = map
152 .entry(key)
153 .or_insert_with(|| (RecoveredValue::String(Bytes::from("0")), -1));
154 if let RecoveredValue::String(ref mut data) = entry.0 {
155 let current = std::str::from_utf8(data)
156 .ok()
157 .and_then(|s| s.parse::<i64>().ok());
158 if let Some(n) = current {
159 if let Some(new_val) = n.checked_add(delta) {
160 *data = Bytes::from(new_val.to_string());
161 }
162 }
163 }
164}
165
166fn replay_aof(
169 path: &Path,
170 map: &mut HashMap<String, (RecoveredValue, i64)>,
171) -> Result<usize, FormatError> {
172 let mut reader = AofReader::open(path)?;
173 let mut count = 0;
174
175 while let Some(record) = reader.read_record()? {
176 match record {
177 AofRecord::Set {
178 key,
179 value,
180 expire_ms,
181 } => {
182 map.insert(key, (RecoveredValue::String(value), expire_ms));
184 }
185 AofRecord::Del { key } => {
186 map.remove(&key);
187 }
188 AofRecord::Expire { key, seconds } => {
189 if let Some(entry) = map.get_mut(&key) {
190 entry.1 = (seconds * 1000) as i64;
191 }
192 }
193 AofRecord::LPush { key, values } => {
194 let entry = map
195 .entry(key)
196 .or_insert_with(|| (RecoveredValue::List(VecDeque::new()), -1));
197 if let RecoveredValue::List(ref mut deque) = entry.0 {
198 for v in values {
199 deque.push_front(v);
200 }
201 }
202 }
203 AofRecord::RPush { key, values } => {
204 let entry = map
205 .entry(key)
206 .or_insert_with(|| (RecoveredValue::List(VecDeque::new()), -1));
207 if let RecoveredValue::List(ref mut deque) = entry.0 {
208 for v in values {
209 deque.push_back(v);
210 }
211 }
212 }
213 AofRecord::LPop { key } => {
214 if let Some(entry) = map.get_mut(&key) {
215 if let RecoveredValue::List(ref mut deque) = entry.0 {
216 deque.pop_front();
217 if deque.is_empty() {
218 map.remove(&key);
219 count += 1;
220 continue;
221 }
222 }
223 }
224 }
225 AofRecord::RPop { key } => {
226 if let Some(entry) = map.get_mut(&key) {
227 if let RecoveredValue::List(ref mut deque) = entry.0 {
228 deque.pop_back();
229 if deque.is_empty() {
230 map.remove(&key);
231 count += 1;
232 continue;
233 }
234 }
235 }
236 }
237 AofRecord::ZAdd { key, members } => {
238 let entry = map
239 .entry(key)
240 .or_insert_with(|| (RecoveredValue::SortedSet(Vec::new()), -1));
241 if let RecoveredValue::SortedSet(ref mut existing) = entry.0 {
242 let mut index: HashMap<String, usize> = existing
244 .iter()
245 .enumerate()
246 .map(|(i, (_, m))| (m.clone(), i))
247 .collect();
248 for (score, member) in members {
249 if let Some(&pos) = index.get(&member) {
250 existing[pos].0 = score;
251 } else {
252 let pos = existing.len();
253 index.insert(member.clone(), pos);
254 existing.push((score, member));
255 }
256 }
257 }
258 }
259 AofRecord::ZRem { key, members } => {
260 if let Some(entry) = map.get_mut(&key) {
261 if let RecoveredValue::SortedSet(ref mut existing) = entry.0 {
262 let to_remove: HashSet<&str> = members.iter().map(|m| m.as_str()).collect();
263 existing.retain(|(_, m)| !to_remove.contains(m.as_str()));
264 if existing.is_empty() {
265 map.remove(&key);
266 count += 1;
267 continue;
268 }
269 }
270 }
271 }
272 AofRecord::Persist { key } => {
273 if let Some(entry) = map.get_mut(&key) {
274 entry.1 = -1; }
276 }
277 AofRecord::Pexpire { key, milliseconds } => {
278 if let Some(entry) = map.get_mut(&key) {
279 entry.1 = milliseconds as i64;
280 }
281 }
282 AofRecord::Incr { key } => {
283 apply_incr(map, key, 1);
284 }
285 AofRecord::Decr { key } => {
286 apply_incr(map, key, -1);
287 }
288 AofRecord::IncrBy { key, delta } => {
289 apply_incr(map, key, delta);
290 }
291 AofRecord::DecrBy { key, delta } => {
292 apply_incr(map, key, -delta);
293 }
294 AofRecord::Append { key, value } => {
295 let entry = map
296 .entry(key)
297 .or_insert_with(|| (RecoveredValue::String(Bytes::new()), -1));
298 if let RecoveredValue::String(ref mut data) = entry.0 {
299 let mut new_data = Vec::with_capacity(data.len() + value.len());
300 new_data.extend_from_slice(data);
301 new_data.extend_from_slice(&value);
302 *data = Bytes::from(new_data);
303 }
304 }
305 AofRecord::Rename { key, newkey } => {
306 if let Some(entry) = map.remove(&key) {
307 map.insert(newkey, entry);
308 }
309 }
310 AofRecord::HSet { key, fields } => {
311 let entry = map
312 .entry(key)
313 .or_insert_with(|| (RecoveredValue::Hash(HashMap::new()), -1));
314 if let RecoveredValue::Hash(ref mut hash) = entry.0 {
315 for (field, value) in fields {
316 hash.insert(field, value);
317 }
318 }
319 }
320 AofRecord::HDel { key, fields } => {
321 if let Some(entry) = map.get_mut(&key) {
322 if let RecoveredValue::Hash(ref mut hash) = entry.0 {
323 for field in fields {
324 hash.remove(&field);
325 }
326 if hash.is_empty() {
327 map.remove(&key);
328 count += 1;
329 continue;
330 }
331 }
332 }
333 }
334 AofRecord::HIncrBy { key, field, delta } => {
335 let entry = map
336 .entry(key)
337 .or_insert_with(|| (RecoveredValue::Hash(HashMap::new()), -1));
338 if let RecoveredValue::Hash(ref mut hash) = entry.0 {
339 let current: i64 = hash
340 .get(&field)
341 .and_then(|v| std::str::from_utf8(v).ok())
342 .and_then(|s| s.parse().ok())
343 .unwrap_or(0);
344 let new_val = current.saturating_add(delta);
345 hash.insert(field, Bytes::from(new_val.to_string()));
346 }
347 }
348 AofRecord::SAdd { key, members } => {
349 let entry = map
350 .entry(key)
351 .or_insert_with(|| (RecoveredValue::Set(HashSet::new()), -1));
352 if let RecoveredValue::Set(ref mut set) = entry.0 {
353 for member in members {
354 set.insert(member);
355 }
356 }
357 }
358 AofRecord::SRem { key, members } => {
359 if let Some(entry) = map.get_mut(&key) {
360 if let RecoveredValue::Set(ref mut set) = entry.0 {
361 for member in members {
362 set.remove(&member);
363 }
364 if set.is_empty() {
365 map.remove(&key);
366 count += 1;
367 continue;
368 }
369 }
370 }
371 }
372 }
373 count += 1;
374 }
375
376 Ok(count)
377}
378
#[cfg(test)]
mod tests {
    use super::*;
    use crate::aof::AofWriter;
    use crate::snapshot::{SnapEntry, SnapValue, SnapshotWriter};

    /// Creates a fresh temporary data directory, removed when dropped.
    fn temp_dir() -> tempfile::TempDir {
        tempfile::tempdir().expect("create temp dir")
    }

    /// Writes `entries` as shard 0's snapshot inside `dir` and finalizes it.
    fn write_snapshot(dir: &Path, entries: &[SnapEntry]) {
        let path = snapshot::snapshot_path(dir, 0);
        let mut writer = SnapshotWriter::create(&path, 0).unwrap();
        for entry in entries {
            writer.write_entry(entry).unwrap();
        }
        writer.finish().unwrap();
    }

    /// Writes `records` to shard 0's AOF inside `dir` and syncs it.
    fn write_aof(dir: &Path, records: &[AofRecord]) {
        let path = aof::aof_path(dir, 0);
        let mut writer = AofWriter::open(&path).unwrap();
        for record in records {
            writer.write_record(record).unwrap();
        }
        writer.sync().unwrap();
    }

    /// Builds a string-valued snapshot entry.
    fn snap_str(key: &str, value: &'static str, expire_ms: i64) -> SnapEntry {
        SnapEntry {
            key: key.into(),
            value: SnapValue::String(Bytes::from(value)),
            expire_ms,
        }
    }

    /// Builds an AOF `Set` record.
    fn set_record(key: &str, value: &'static str, expire_ms: i64) -> AofRecord {
        AofRecord::Set {
            key: key.into(),
            value: Bytes::from(value),
            expire_ms,
        }
    }

    /// Indexes recovered entries by key so assertions don't depend on order.
    fn values_by_key(result: &RecoveryResult) -> HashMap<&str, RecoveredValue> {
        result
            .entries
            .iter()
            .map(|e| (e.key.as_str(), e.value.clone()))
            .collect()
    }

    #[test]
    fn empty_dir_returns_empty_result() {
        let dir = temp_dir();
        let result = recover_shard(dir.path(), 0);
        assert!(result.entries.is_empty());
        assert!(!result.loaded_snapshot);
        assert!(!result.replayed_aof);
    }

    #[test]
    fn snapshot_only_recovery() {
        let dir = temp_dir();
        write_snapshot(
            dir.path(),
            &[snap_str("a", "1", -1), snap_str("b", "2", 60_000)],
        );

        let result = recover_shard(dir.path(), 0);
        assert!(result.loaded_snapshot);
        assert!(!result.replayed_aof);
        assert_eq!(result.entries.len(), 2);
    }

    #[test]
    fn aof_only_recovery() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[set_record("x", "10", -1), set_record("y", "20", -1)],
        );

        let result = recover_shard(dir.path(), 0);
        assert!(!result.loaded_snapshot);
        assert!(result.replayed_aof);
        assert_eq!(result.entries.len(), 2);
    }

    #[test]
    fn snapshot_plus_aof_overlay() {
        let dir = temp_dir();
        // The snapshot holds the old value; the AOF overwrites it and adds a key.
        write_snapshot(dir.path(), &[snap_str("a", "old", -1)]);
        write_aof(
            dir.path(),
            &[set_record("a", "new", -1), set_record("b", "added", -1)],
        );

        let result = recover_shard(dir.path(), 0);
        assert!(result.loaded_snapshot);
        assert!(result.replayed_aof);

        let map = values_by_key(&result);
        assert!(matches!(&map["a"], RecoveredValue::String(b) if b == &Bytes::from("new")));
        assert!(matches!(&map["b"], RecoveredValue::String(b) if b == &Bytes::from("added")));
    }

    #[test]
    fn del_removes_entry_during_replay() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                set_record("gone", "temp", -1),
                AofRecord::Del { key: "gone".into() },
            ],
        );

        let result = recover_shard(dir.path(), 0);
        assert!(result.entries.is_empty());
    }

    #[test]
    fn expired_entries_skipped() {
        let dir = temp_dir();
        // expire_ms == 0 marks an already-expired entry; 60_000 is still live.
        write_snapshot(
            dir.path(),
            &[
                snap_str("dead", "gone", 0),
                snap_str("alive", "here", 60_000),
            ],
        );

        let result = recover_shard(dir.path(), 0);
        assert_eq!(result.entries.len(), 1);
        assert_eq!(result.entries[0].key, "alive");
    }

    #[test]
    fn corrupt_snapshot_starts_empty() {
        let dir = temp_dir();
        let path = snapshot::snapshot_path(dir.path(), 0);
        std::fs::write(&path, b"garbage data").unwrap();

        let result = recover_shard(dir.path(), 0);
        assert!(!result.loaded_snapshot);
        assert!(result.entries.is_empty());
    }

    #[test]
    fn sorted_set_snapshot_recovery() {
        let dir = temp_dir();
        write_snapshot(
            dir.path(),
            &[SnapEntry {
                key: "board".into(),
                value: SnapValue::SortedSet(vec![
                    (100.0, "alice".into()),
                    (200.0, "bob".into()),
                ]),
                expire_ms: -1,
            }],
        );

        let result = recover_shard(dir.path(), 0);
        assert!(result.loaded_snapshot);
        assert_eq!(result.entries.len(), 1);
        match &result.entries[0].value {
            RecoveredValue::SortedSet(members) => {
                assert_eq!(members.len(), 2);
                assert!(members.contains(&(100.0, "alice".into())));
                assert!(members.contains(&(200.0, "bob".into())));
            }
            other => panic!("expected SortedSet, got {other:?}"),
        }
    }

    #[test]
    fn sorted_set_aof_replay() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                AofRecord::ZAdd {
                    key: "board".into(),
                    members: vec![(100.0, "alice".into()), (200.0, "bob".into())],
                },
                AofRecord::ZRem {
                    key: "board".into(),
                    members: vec!["alice".into()],
                },
            ],
        );

        let result = recover_shard(dir.path(), 0);
        assert!(result.replayed_aof);
        assert_eq!(result.entries.len(), 1);
        match &result.entries[0].value {
            RecoveredValue::SortedSet(members) => {
                assert_eq!(members.len(), 1);
                assert_eq!(members[0], (200.0, "bob".into()));
            }
            other => panic!("expected SortedSet, got {other:?}"),
        }
    }

    #[test]
    fn sorted_set_zrem_auto_deletes_empty() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                AofRecord::ZAdd {
                    key: "board".into(),
                    members: vec![(100.0, "alice".into())],
                },
                AofRecord::ZRem {
                    key: "board".into(),
                    members: vec!["alice".into()],
                },
            ],
        );

        let result = recover_shard(dir.path(), 0);
        assert!(result.entries.is_empty());
    }

    #[test]
    fn expire_record_updates_ttl() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                set_record("k", "v", -1),
                AofRecord::Expire {
                    key: "k".into(),
                    seconds: 300,
                },
            ],
        );

        let result = recover_shard(dir.path(), 0);
        assert_eq!(result.entries.len(), 1);
        assert!(result.entries[0].ttl.is_some());
    }

    #[test]
    fn persist_record_removes_ttl() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                set_record("k", "v", 60_000),
                AofRecord::Persist { key: "k".into() },
            ],
        );

        let result = recover_shard(dir.path(), 0);
        assert_eq!(result.entries.len(), 1);
        assert!(result.entries[0].ttl.is_none());
    }

    #[test]
    fn incr_decr_replay() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                set_record("n", "10", -1),
                AofRecord::Incr { key: "n".into() },
                AofRecord::Incr { key: "n".into() },
                AofRecord::Decr { key: "n".into() },
                // INCR on a key that was never set starts from zero.
                AofRecord::Incr {
                    key: "fresh".into(),
                },
            ],
        );

        let result = recover_shard(dir.path(), 0);
        let map = values_by_key(&result);

        match &map["n"] {
            RecoveredValue::String(data) => assert_eq!(data, &Bytes::from("11")),
            other => panic!("expected String(\"11\"), got {other:?}"),
        }
        match &map["fresh"] {
            RecoveredValue::String(data) => assert_eq!(data, &Bytes::from("1")),
            other => panic!("expected String(\"1\"), got {other:?}"),
        }
    }

    #[test]
    fn pexpire_record_sets_ttl() {
        let dir = temp_dir();
        write_aof(
            dir.path(),
            &[
                set_record("k", "v", -1),
                AofRecord::Pexpire {
                    key: "k".into(),
                    milliseconds: 5000,
                },
            ],
        );

        let result = recover_shard(dir.path(), 0);
        assert_eq!(result.entries.len(), 1);
        assert!(result.entries[0].ttl.is_some());
    }
}